def test_immutable_augassign(self):
    """Tuple ``+=`` rebinds ``x``; the alias ``y`` keeps the original 3-tuple type."""
    class Test():
        def forward(self):
            x = (1, 2, 3)
            y = x
            x += (4,)
            return y

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> (int, int, int)",  # FunctionDef forward (line 1)
        5: "NoneType",  # Assign x = (1, 2, 3) (line 2)
        6: "(int, int, int)",  # Name x (line 2)
        8: "(int, int, int)",  # Tuple (1, 2, 3) (line 2)
        9: "int",  # Num 1 (line 2)
        10: "int",  # Num 2 (line 2)
        11: "int",  # Num 3 (line 2)
        13: "NoneType",  # Assign y = x (line 3)
        14: "(int, int, int)",  # Name y (line 3)
        16: "(int, int, int)",  # Name x (line 3)
        18: "NoneType",  # AugAssign x += (4,) (line 4)
        19: "(int, int, int, int)",  # Name x (line 4)
        21: "(int, int, int) -> (int,) -> (int, int, int, int)",  # Add
        22: "(int,)",  # Tuple (4,) (line 4)
        23: "int",  # Num 4 (line 4)
        25: "(int, int, int)",  # Return (line 5)
        26: "(int, int, int)",  # Name y (line 5)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_DCGAN_Discriminator(self):
    """Type-check the DCGAN discriminator's ``forward`` on a (5, 3, 64, 64) input."""
    type_inference_tools.reset_state()

    model, forward_args = gen_DCGAN_Discriminator_test()
    id2type = generate_id2type_from_forward(model, forward_args)

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Discriminator -> torch.Tensor(float32, (5, 3, 64, 64)) -> torch.Tensor(float32, (5,))",  # FunctionDef forward (line 1)
        7: "NoneType",  # Assign
        10: "torch.Tensor(float32, (5, 1, 1, 1))",  # Call self.main(input) (line 6)
        12: "class Discriminator",  # Name self (line 6)
        15: "torch.Tensor(float32, (5, 3, 64, 64))",  # Name input (line 6)
        17: "torch.Tensor(float32, (5,))",  # Return
        18: "torch.Tensor(float32, (5,))",  # Call output.view(-1, 1).squeeze(dim=1) (line 8)
        20: "torch.Tensor(float32, (5, 1))",  # Call output.view(-1, 1) (line 8)
        22: "torch.Tensor(float32, (5, 1, 1, 1))",  # Name output (line 8)
        25: "int",  # UnaryOp -1 (line 8)
        27: "int",  # Constant 1 (line 8)
        28: "int",  # Constant 1 (line 8)
        31: "int",  # Constant 1 (line 8)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_lazy_attribute_init(self):
    """An attribute initialised to ``None`` and lazily set to 42 is returned as int."""
    class Test():
        def __init__(self):
            self.y = None

        def forward(self):
            if self.y is None:
                self.y = 42
            return self.y

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int",  # FunctionDef forward (line 1)
        5: "NoneType",  # If
        6: "bool",  # Compare self.y is None (line 2)
        7: "NoneType",  # Attribute self.y (line 2)
        8: "class Test",  # Name self (line 2)
        13: "NoneType",  # Assign
        14: "int",  # Attribute self.y (line 3)
        15: "class Test",  # Name self (line 3)
        18: "int",  # Num 42 (line 3)
        19: "int",  # Return
        20: "int",  # Attribute self.y (line 4)
        21: "class Test",  # Name self (line 4)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_calling_user_defined_method(self):
    """Calling a method on a user-defined attribute object propagates the int argument."""
    class A():
        def f(self, x):
            return x

    class Test():
        def __init__(self):
            self.a = A()

        def forward(self, x):
            return self.a.f(x)

    id2type = generate_id2type_from_forward(Test(), (1, ))

    # node id -> expected str() of the inferred type ("line N" is relative to the
    # enclosing function)
    expected = {
        1: "class Test -> int -> int",  # FunctionDef forward (line 1)
        7: "int",  # Return
        8: "int",  # Call self.a.f(x) (line 2)
        10: "class A",  # Attribute self.a (line 2)
        11: "class Test",  # Name self (line 2)
        15: "int",  # Name x (line 2)
        23: "int",  # Return (A.f, line 2)
        24: "int",  # Name x (A.f, line 2)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_list_slice(self):
    """Slicing an int list yields an int list."""
    class Test():
        def forward(self):
            x = [0, 1, 2, 3]
            return x[1:2]

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int list",  # FunctionDef forward (line 1)
        5: "NoneType",  # Assign
        6: "int list",  # Name x (line 2)
        8: "int list",  # List [0, 1, 2, 3] (line 2)
        9: "int",  # Num 0 (line 2)
        10: "int",  # Num 1 (line 2)
        11: "int",  # Num 2 (line 2)
        12: "int",  # Num 3 (line 2)
        14: "int list",  # Return
        15: "int list",  # Subscript x[1:2] (line 3)
        16: "int list",  # Name x (line 3)
        19: "int",  # Num 1 (line 3)
        20: "int",  # Num 2 (line 3)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_tuple_2(self):
    """Tuple ``+=`` concatenation grows ``x`` from a 2-tuple to a 3-tuple."""
    class Test():
        def forward(self):
            x = (1, 2)
            x += (3, )
            return x

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> (int, int, int)",  # FunctionDef forward (line 1)
        5: "NoneType",  # Assign (line 2)
        6: "(int, int)",  # Name x (line 2)
        8: "(int, int)",  # Tuple (1, 2) (line 2)
        9: "int",  # Num 1 (line 2)
        10: "int",  # Num 2 (line 2)
        12: "NoneType",  # AugAssign (line 3)
        13: "(int, int, int)",  # Name x (line 3)
        16: "(int,)",  # Tuple (3,) (line 3)
        17: "int",  # Num 3 (line 3)
        19: "(int, int, int)",  # Return (line 4)
        20: "(int, int, int)",  # Name x (line 4)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_list(self):
    """Appending slices of a fixed-length int list to an empty list gives ``int list list``."""
    class Test():
        def forward(self):
            xs = [1, 2, 3]
            v = []
            for i in range(3):
                v.append(xs[:i])
            return v

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int list list",  # FunctionDef (line 1)
        5: "NoneType",  # Assign (line 2)
        6: "[int, int, int]",  # Name xs (line 2)
        8: "[int, int, int]",  # List [1, 2, 3] (line 2)
        9: "int",  # Num 1 (line 2)
        10: "int",  # Num 2 (line 2)
        11: "int",  # Num 3 (line 2)
        13: "NoneType",  # Assign (line 3)
        14: "[]",  # Name v (line 3)
        16: "[]",  # List [] (line 3)
        18: "NoneType",  # For (line 4)
        19: "int",  # Name i (line 4)
        21: "int list",  # Call range(3) (line 4)
        22: "int -> int list",  # Name range (line 4)
        24: "int",  # Num 3 (line 4)
        26: "NoneType",  # Call v.append(xs[:i]) (line 5)
        27: "int list -> NoneType",  # Attribute v.append (line 5)
        28: "int list list",  # Name v (line 5)
        31: "int list",  # Subscript xs[:i] (line 5)
        32: "int list",  # Name xs (line 5)
        35: "int",  # Name i (line 5)
        38: "int list list",  # Return (line 6)
        39: "int list list",  # Name v (line 6)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_mutable_augassign(self):
    """List ``+=`` through an alias: ``x`` and ``y`` are typed ``int list``, not a fixed-length list."""
    class Test():
        def forward(self):
            x = [1, 2, 3]
            y = x
            x += [4]
            return y

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int list",  # FunctionDef forward (line 1)
        5: "NoneType",  # Assign (line 2)
        6: "[int, int, int]",  # Name x (line 2)
        8: "[int, int, int]",  # List [1, 2, 3] (line 2)
        9: "int",  # Num 1 (line 2)
        10: "int",  # Num 2 (line 2)
        11: "int",  # Num 3 (line 2)
        13: "NoneType",  # Assign (line 3)
        14: "int list",  # Name y (line 3)
        16: "int list",  # Name x (line 3)
        18: "NoneType",  # AugAssign (line 4)
        19: "int list",  # Name x (line 4)
        21: "int list -> [int] -> int list",  # Add
        22: "[int]",  # List [4] (line 4)
        23: "int",  # Num 4 (line 4)
        25: "int list",  # Return (line 5)
        26: "int list",  # Name y (line 5)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_type_hints(self):
    """Named dimensions from a ``TyNdarray`` hint flow through ``F.split_axis``."""
    class Test():
        def forward(self, x: types.TyNdarray(np.float32, ('a', 'b'))):
            h = F.split_axis(x, 2, 1)
            return h

    model = Test()
    forward_args = (np.zeros((10, 10)).astype(np.float32), )
    id2type = generate_id2type_from_forward(model, forward_args)

    # Reused type strings: the hinted input, and the pair returned by split_axis.
    x_ty = "ndarray(float32, (10 (a), 10 (b)))"
    half_ty = "Variable(float32, (10 (a), 5 (b // 2)))"
    split_ty = "(" + half_ty + ", " + half_ty + ")"

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> " + x_ty + " -> " + split_ty,  # FunctionDef forward (line 1)
        20: "NoneType",  # Assign
        21: split_ty,  # Name h (line 2)
        23: split_ty,  # Call F.split_axis(x, 2, 1) (line 2)
        28: x_ty,  # Name x (line 2)
        30: "int",  # Constant 2 (line 2)
        31: "int",  # Constant 1 (line 2)
        32: split_ty,  # Return
        33: split_ty,  # Name h (line 3)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_calling_user_defined_callable_nested(self):
    """A callable object whose ``__call__`` delegates to another method is typed through both."""
    class B():
        def f(self):
            return 1

        def __call__(self):
            return self.f()

    class Test():
        def __init__(self):
            self.b = B()

        def forward(self):
            return self.b()

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to the
    # enclosing function)
    expected = {
        1: "class Test -> int",  # FunctionDef forward (line 1)
        5: "int",  # Return (line 2)
        6: "int",  # Call self.b() (line 2)
        7: "class B -> int",  # Attribute b (line 2)
        8: "class Test",  # Name self (line 2)
        11: "class B -> int",  # FunctionDef __call__ (line 1)
        15: "int",  # Return (line 2)
        16: "int",  # Call self.f() (line 2)
        17: "class B -> int",  # Attribute f (line 2)
        18: "class B",  # Name self (line 2)
        21: "class B -> int",  # FunctionDef f (line 1)
        25: "int",  # Return (line 2)
        26: "int",  # Num 1 (line 2)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_lazy_init_branch_if(self):
    """With a ``None`` input, ``x += 1`` on the else branch gets a type variable."""
    # NOTE(review): presumably resets the type-variable counter so the variable
    # asserted below is named ``a0`` — confirm against type_inference_tools.
    type_inference_tools.reset_state()

    # The line layout of forward matters: the expected type variable records
    # "line 5", which is the ``x += 1`` line. Do not add lines inside forward.
    class Test(chainer.Chain):
        def forward(self, x):
            if x is None:
                x = 42
            else:
                x += 1
            return x

    id2type = generate_id2type_from_forward(Test(), (None,))

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> NoneType -> int",  # FunctionDef forward (line 1)
        7: "NoneType",  # If
        8: "bool",  # Compare x is None (line 2)
        9: "int",  # Name x (line 2)
        13: "NoneType",  # Assign
        14: "int",  # Name x (line 3)
        16: "int",  # Num 42 (line 3)
        17: "NoneType",  # AugAssign
        18: "a0 (from line 5)",  # Name x (line 5)
        20: "NoneType -> int -> a0 (from line 5)",  # Add
        21: "int",  # Num 1 (line 5)
        22: "int",  # Return
        23: "int",  # Name x (line 6)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_num_coercion_if(self):
    """Rebinding ``b`` to ``b + 1.0`` under ``if`` makes that ``b`` float while ``a`` stays int."""
    class Test():
        def forward(self):
            a = 1
            b = a
            if True:
                b = b + 1.0
            return a

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int",  # FunctionDef forward (line 1)
        5: "NoneType",  # Assign (line 2)
        6: "int",  # Name a (line 2)
        8: "int",  # Num 1 (line 2)
        9: "NoneType",  # Assign (line 3)
        10: "int",  # Name b (line 3)
        12: "int",  # Name a (line 3)
        14: "NoneType",  # If (line 4)
        15: "bool",  # NameConstant True (line 4)
        16: "NoneType",  # Assign (line 5)
        17: "float",  # Name b (line 5)
        19: "float",  # BinOp b + 1.0 (line 5)
        20: "int",  # Name b (line 5)
        22: "int -> float -> float",  # Add
        23: "float",  # Num 1.0 (line 5)
        24: "int",  # Return (line 6)
        25: "int",  # Name a (line 6)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_tuple_coercion(self):
    """Indexing a tuple with a loop variable makes the indexed name an ``int tuple``."""
    class Test():
        def forward(self):
            x = (1, 2, 3)
            for i in range(1, 3):
                o = x[i]
            return o

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int",  # FunctionDef (line 1)
        5: "NoneType",  # Assign (line 2)
        6: "(int, int, int)",  # Name x (line 2)
        8: "(int, int, int)",  # Tuple (1, 2, 3) (line 2)
        9: "int",  # Num 1 (line 2)
        10: "int",  # Num 2 (line 2)
        11: "int",  # Num 3 (line 2)
        13: "NoneType",  # For (line 3)
        14: "int",  # Name i (line 3)
        16: "int list",  # Call range(1, 3) (line 3)
        17: "int -> int -> int list",  # Name range (line 3)
        19: "int",  # Num 1 (line 3)
        20: "int",  # Num 3 (line 3)
        21: "NoneType",  # Assign (line 4)
        22: "int",  # Name o (line 4)
        24: "int",  # Subscript x[i] (line 4)
        25: "int tuple",  # Name x (line 4)
        28: "int",  # Name i (line 4)
        31: "int",  # Return (line 5)
        32: "int",  # Name o (line 5)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_separate(self):
    """``F.separate`` along axis 0 of a (3, 4, 5) array yields three (4, 5) Variables.

    The expected strings use the same compact ``ndarray(float64, (...))`` /
    ``Variable(float64, (...))`` rendering as the other array tests in this
    file, e.g. ``test_split_axis``. All of them go through the same type
    ``__str__``, so they must share one format.
    """
    class Test():
        def forward(self):
            F.separate(np.zeros((3, 4, 5)), axis=0)

    id2type = generate_id2type_from_forward(Test(), ())

    arr_ty = "ndarray(float64, (3, 4, 5))"
    part_ty = "Variable(float64, (4, 5))"
    parts_ty = "(" + ", ".join([part_ty] * 3) + ")"

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> NoneType",  # FunctionDef forward (line 1)
        5: "NoneType",  # Expr
        6: parts_ty,  # Call F.separate(np.zeros((3, 4, 5)), axis=0) (line 2)
        7: arr_ty + " -> " + parts_ty,  # Attribute F.separate (line 2)
        11: arr_ty,  # Call np.zeros((3, 4, 5)) (line 2)
        12: "(int, int, int) -> " + arr_ty,  # Attribute np.zeros (line 2)
        16: "(int, int, int)",  # Tuple (3, 4, 5) (line 2)
        17: "int",  # Num 3 (line 2)
        18: "int",  # Num 4 (line 2)
        19: "int",  # Num 5 (line 2)
        22: "int",  # Num 0 (line 2)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_split_axis(self):
    """``F.split_axis`` into 2 sections along axis 1 of a (3, 4, 5) array gives two (3, 2, 5) Variables."""
    class Test():
        def forward(self):
            F.split_axis(np.zeros((3, 4, 5)), 2, 1)

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> NoneType",  # FunctionDef forward (line 1)
        5: "NoneType",  # Expr
        6: "(Variable(float64, (3, 2, 5)), Variable(float64, (3, 2, 5)))",  # Call F.split_axis(np.zeros((3, 4, 5)), 2, 1) (line 2)
        11: "ndarray(float64, (3, 4, 5))",  # Call np.zeros((3, 4, 5)) (line 2)
        16: "(int, int, int)",  # Tuple (3, 4, 5) (line 2)
        17: "int",  # Num 3 (line 2)
        18: "int",  # Num 4 (line 2)
        19: "int",  # Num 5 (line 2)
        21: "int",  # Num 2 (line 2)
        22: "int",  # Num 1 (line 2)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_for_simple(self):
    """Reassigning the parameter with ``float(i) + 1`` in a loop types it (and the result) as float."""
    class Test():
        def forward(self, x):
            for i in range(2):
                x = float(i) + 1
            return x

    id2type = generate_id2type_from_forward(Test(), (0,))

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> float -> float",  # FunctionDef forward (line 1)
        7: "NoneType",  # For (line 2)
        8: "int",  # Name i (line 2)
        10: "int list",  # Call range(2) (line 2)
        11: "int -> int list",  # Name range (line 2)
        13: "int",  # Num 2 (line 2)
        14: "NoneType",  # Assign (line 3)
        15: "float",  # Name x (line 3)
        17: "float",  # BinOp float(i) + 1 (line 3)
        18: "float",  # Call float(i) (line 3)
        19: "int -> float",  # Name float (line 3)
        21: "int",  # Name i (line 3)
        23: "float -> int -> float",  # Add
        24: "int",  # Num 1 (line 3)
        25: "float",  # Return (line 4)
        26: "float",  # Name x (line 4)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_numpy_zeros(self):
    """``np.zeros`` defaults to float64 and honours a string ``dtype`` keyword."""
    class Test():
        def forward(self):
            x = np.zeros((3, 3))
            y = np.zeros(3, dtype='int64')

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> NoneType",  # FunctionDef forward (line 1)
        5: "NoneType",  # Assign
        6: "ndarray(float64, (3, 3))",  # Name x (line 2)
        8: "ndarray(float64, (3, 3))",  # Call np.zeros((3, 3)) (line 2)
        13: "(int, int)",  # Tuple (3, 3) (line 2)
        14: "int",  # Num 3 (line 2)
        15: "int",  # Num 3 (line 2)
        17: "NoneType",  # Assign
        18: "ndarray(int64, (3,))",  # Name y (line 3)
        20: "ndarray(int64, (3,))",  # Call np.zeros(3, dtype='int64') (line 3)
        25: "int",  # Num 3 (line 3)
        27: "string",  # Str 'int64' (line 3)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_string(self):
    """String ``+=`` in a loop over string literals keeps everything typed ``string``."""
    class Test():
        def forward(self):
            v = "foobar"
            for x in ["a", "b"]:
                v += x
            return v

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> string",  # FunctionDef (line 1)
        5: "NoneType",  # Assign (line 2)
        6: "string",  # Name v (line 2)
        8: "string",  # Str "foobar" (line 2)
        9: "NoneType",  # For (line 3)
        10: "string",  # Name x (line 3)
        12: "string list",  # List ["a", "b"] (line 3)
        13: "string",  # Str "a" (line 3)
        14: "string",  # Str "b" (line 3)
        16: "NoneType",  # AugAssign (line 4)
        17: "string",  # Name v (line 4)
        19: "string -> string -> string",  # Add
        20: "string",  # Name x (line 4)
        22: "string",  # Return (line 5)
        23: "string",  # Name v (line 5)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_list_comprehension(self):
    """A list comprehension calling ``self.f`` takes its element type from ``f``'s return type."""
    class Test():
        def f(self, x):
            return x

        def forward(self, x):
            y = [self.f(i) for i in range(x)]
            return y

    id2type = generate_id2type_from_forward(Test(), (1,))

    # node id -> expected str() of the inferred type ("line N" is relative to the
    # enclosing function)
    expected = {
        1: "class Test -> int -> int list",  # FunctionDef forward (line 1)
        7: "NoneType",  # Assign (line 2)
        8: "int list",  # Name y (line 2)
        10: "int list",  # ListComp (line 2)
        11: "int",  # Call self.f(i) (line 2)
        12: "class Test -> int -> int",  # Attribute self.f (line 2)
        13: "class Test",  # Name self (line 2)
        16: "int",  # Name i (line 2)
        19: "int",  # Name i (line 2)
        21: "int list",  # Call range(x) (line 2)
        22: "int -> int list",  # Name range (line 2)
        24: "int",  # Name x (line 2)
        26: "int list",  # Return (line 3)
        27: "int list",  # Name y (line 3)
        29: "class Test -> int -> int",  # FunctionDef f (line 1)
        35: "int",  # Return (line 2)
        36: "int",  # Name x (line 2)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_list_of_tuple(self):
    """Unpacking ``(int, float)`` pairs in a ``for`` target and summing into ``v`` gives float."""
    class Test():
        def forward(self, v):
            for x, y in [(1, 2.0), (2, 3.0)]:
                v += x + y
            return v

    id2type = generate_id2type_from_forward(Test(), (0,))

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> float -> float",  # FunctionDef forward (line 1)
        7: "NoneType",  # For
        8: "(int, float)",  # Tuple (x, y) (line 2)
        9: "int",  # Name x (line 2)
        11: "float",  # Name y (line 2)
        14: "(int, float) list",  # List [(1, 2.0), (2, 3.0)] (line 2)
        15: "(int, float)",  # Tuple (1, 2.0) (line 2)
        16: "int",  # Num 1 (line 2)
        17: "float",  # Num 2.0 (line 2)
        19: "(int, float)",  # Tuple (2, 3.0) (line 2)
        20: "int",  # Num 2 (line 2)
        21: "float",  # Num 3.0 (line 2)
        24: "NoneType",  # AugAssign
        25: "float",  # Name v (line 3)
        27: "float -> float -> float",  # Add (v += ...)
        28: "float",  # BinOp x + y (line 3)
        29: "int",  # Name x (line 3)
        31: "int -> float -> float",  # Add (x + y)
        32: "float",  # Name y (line 3)
        34: "float",  # Return
        35: "float",  # Name v (line 4)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_lazy_init_branch_else(self):
    """With an int input, the same branchy ``forward`` is typed without any type variable."""
    class Test(chainer.Chain):
        def forward(self, x):
            if x is None:
                x = 42
            else:
                x += 1
            return x

    # XXX: the input (2) differs from test_lazy_init_branch_if, which passes None.
    id2type = generate_id2type_from_forward(Test(), (2,))

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int -> int",  # FunctionDef forward (line 1)
        7: "NoneType",  # If
        8: "bool",  # Compare x is None (line 2)
        9: "int",  # Name x (line 2)
        13: "NoneType",  # Assign
        14: "int",  # Name x (line 3)
        16: "int",  # Num 42 (line 3)
        17: "NoneType",  # AugAssign
        18: "int",  # Name x (line 5)
        20: "int -> int -> int",  # Add
        21: "int",  # Num 1 (line 5)
        22: "int",  # Return
        23: "int",  # Name x (line 6)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_mutable_attribute_assign(self):
    """``+=`` on a local alias of a list attribute; the attribute is returned as ``int list``."""
    class Test():
        def __init__(self):
            self.a = [1, 2, 3]

        def forward(self):
            b = self.a
            b += [4]
            return self.a

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> int list",  # FunctionDef forward (line 1)
        5: "NoneType",  # Assign (line 2)
        6: "int list",  # Name b (line 2)
        8: "int list",  # Attribute self.a (line 2)
        9: "class Test",  # Name self (line 2)
        12: "NoneType",  # AugAssign (line 3)
        13: "int list",  # Name b (line 3)
        15: "int list -> [int] -> int list",  # Add
        16: "[int]",  # List [4] (line 3)
        17: "int",  # Num 4 (line 3)
        19: "int list",  # Return (line 4)
        20: "int list",  # Attribute self.a (line 4)
        21: "class Test",  # Name self (line 4)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_calling_user_defined_callable_class(self):
    """Calling a user-defined callable object and adding an int argument yields int."""
    class B():
        def __call__(self):
            return 1

    class Test():
        def __init__(self):
            self.b = B()

        def forward(self, x):
            return self.b() + x

    id2type = generate_id2type_from_forward(Test(), (1,))

    # node id -> expected str() of the inferred type ("line N" is relative to the
    # enclosing function)
    expected = {
        1: "class Test -> int -> int",  # FunctionDef forward (line 1)
        7: "int",  # Return (line 2)
        8: "int",  # BinOp self.b() + x (line 2)
        9: "int",  # Call self.b() (line 2)
        10: "class B -> int",  # Attribute b (line 2)
        11: "class Test",  # Name self (line 2)
        14: "int -> int -> int",  # Add
        15: "int",  # Name x (line 2)
        17: "class B -> int",  # FunctionDef __call__ (line 1)
        21: "int",  # Return (line 2)
        22: "int",  # Num 1 (line 2)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_vstack(self):
    """``F.vstack`` of arrays with different first dimensions reports that dimension as ``None``."""
    class Test():
        def forward(self):
            F.vstack([np.zeros((1, 3, 4)), np.zeros((2, 3, 4))])

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> NoneType",  # FunctionDef forward (line 1)
        5: "NoneType",  # Expr
        6: "Variable(float64, (None, 3, 4))",  # Call F.vstack([...]) (line 2)
        11: "ndarray(float64, (None, 3, 4)) list",  # List [np.zeros((1, 3, 4)), np.zeros((2, 3, 4))] (line 2)
        12: "ndarray(float64, (1, 3, 4))",  # Call np.zeros((1, 3, 4)) (line 2)
        17: "(int, int, int)",  # Tuple (1, 3, 4) (line 2)
        18: "int",  # Num 1 (line 2)
        19: "int",  # Num 3 (line 2)
        20: "int",  # Num 4 (line 2)
        22: "ndarray(float64, (2, 3, 4))",  # Call np.zeros((2, 3, 4)) (line 2)
        27: "(int, int, int)",  # Tuple (2, 3, 4) (line 2)
        28: "int",  # Num 2 (line 2)
        29: "int",  # Num 3 (line 2)
        30: "int",  # Num 4 (line 2)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_numpy_array(self):
    """``np.array``, ``np.zeros`` and ``astype`` with dtype given as default, object, or string."""
    class Test():
        def forward(self):
            x = np.array([0])
            y = np.array(0, dtype=np.float64)
            z = np.array([0], dtype='float32')
            w = np.zeros(0).astype('f')
            u = np.zeros(0).astype(np.int32)

    id2type = generate_id2type_from_forward(Test(), ())

    # node id -> expected str() of the inferred type ("line N" is relative to forward)
    expected = {
        1: "class Test -> NoneType",  # FunctionDef forward (line 1)
        # x = np.array([0])
        5: "NoneType",  # Assign
        6: "ndarray(int64, (None,))",  # Name x (line 2)
        8: "ndarray(int64, (None,))",  # Call np.array([0]) (line 2)
        13: "int list",  # List [0] (line 2)
        14: "int",  # Num 0 (line 2)
        # y = np.array(0, dtype=np.float64)
        16: "NoneType",  # Assign
        17: "ndarray(float64, ())",  # Name y (line 3)
        19: "ndarray(float64, ())",  # Call np.array(0, dtype=np.float64) (line 3)
        24: "int",  # Num 0 (line 3)
        26: "dtype(float64)",  # Attribute np.float64 (line 3)
        # z = np.array([0], dtype='float32')
        30: "NoneType",  # Assign
        31: "ndarray(float32, (None,))",  # Name z (line 4)
        33: "ndarray(float32, (None,))",  # Call np.array([0], dtype='float32') (line 4)
        38: "int list",  # List [0] (line 4)
        39: "int",  # Num 0 (line 4)
        42: "string",  # Str 'float32' (line 4)
        # w = np.zeros(0).astype('f')
        43: "NoneType",  # Assign
        44: "ndarray(float32, (0,))",  # Name w (line 5)
        46: "ndarray(float32, (0,))",  # Call np.zeros(0).astype('f') (line 5)
        48: "ndarray(float64, (0,))",  # Call np.zeros(0) (line 5)
        53: "int",  # Num 0 (line 5)
        55: "string",  # Str 'f' (line 5)
        # u = np.zeros(0).astype(np.int32)
        56: "NoneType",  # Assign
        57: "ndarray(int32, (0,))",  # Name u (line 6)
        59: "ndarray(int32, (0,))",  # Call np.zeros(0).astype(np.int32) (line 6)
        61: "ndarray(float64, (0,))",  # Call np.zeros(0) (line 6)
        66: "int",  # Num 0 (line 6)
        68: "dtype(int32)",  # Attribute np.int32 (line 6)
    }
    for node_id, want in expected.items():
        self.assertEqual(str(id2type[node_id]), want)
def test_sum(self):
    """Check inferred types for ``F.sum`` with ``axis`` / ``keepdims`` variants.

    Each entry pairs a node id from ``generate_id2type_from_forward`` with
    the expected ``str()`` of its inferred type; the trailing comment names
    the AST node and the line number recorded for it.
    """
    class Test():
        def forward(self):
            F.sum(np.zeros((1, 2, 3)), axis=-1)
            F.sum(np.zeros((1, 2, 3)), axis=1, keepdims=True)
            F.sum(np.zeros((1, 2, 3)), keepdims=True)

    types_by_id = generate_id2type_from_forward(Test(), ())

    expectations = [
        (1, "class Test -> NoneType"),  # FunctionDef forward (line 1)
        # F.sum(np.zeros((1, 2, 3)), axis=-1)
        (5, "NoneType"),  # Expr
        (6, "Variable(float64, (1, 2))"),  # Call F.sum(np.zeros((1, 2, 3)), axis=-1) (line 3)
        (11, "ndarray(float64, (1, 2, 3))"),  # Call np.zeros((1, 2, 3)) (line 3)
        (16, "(int, int, int)"),  # Tuple (1, 2, 3) (line 3)
        (17, "int"),  # Num 1 (line 3)
        (18, "int"),  # Num 2 (line 3)
        (19, "int"),  # Num 3 (line 3)
        (22, "int"),  # UnaryOp -1 (line 3)
        (24, "int"),  # Num 1 (line 3)
        # F.sum(np.zeros((1, 2, 3)), axis=1, keepdims=True)
        (25, "NoneType"),  # Expr
        (26, "Variable(float64, (1, 1, 3))"),  # Call F.sum(..., axis=1, keepdims=True) (line 4)
        (31, "ndarray(float64, (1, 2, 3))"),  # Call np.zeros((1, 2, 3)) (line 4)
        (36, "(int, int, int)"),  # Tuple (1, 2, 3) (line 4)
        (37, "int"),  # Num 1 (line 4)
        (38, "int"),  # Num 2 (line 4)
        (39, "int"),  # Num 3 (line 4)
        (42, "int"),  # Num 1 (line 4)
        (44, "bool"),  # NameConstant True (line 4)
        # F.sum(np.zeros((1, 2, 3)), keepdims=True)
        (45, "NoneType"),  # Expr
        (46, "Variable(float64, (1, 1, 1))"),  # Call F.sum(..., keepdims=True) (line 5)
        (51, "ndarray(float64, (1, 2, 3))"),  # Call np.zeros((1, 2, 3)) (line 5)
        (56, "(int, int, int)"),  # Tuple (1, 2, 3) (line 5)
        (57, "int"),  # Num 1 (line 5)
        (58, "int"),  # Num 2 (line 5)
        (59, "int"),  # Num 3 (line 5)
        (62, "bool"),  # NameConstant True (line 5)
    ]
    for node_id, expected_type in expectations:
        self.assertEqual(str(types_by_id[node_id]), expected_type)
def test_expand_dims(self):
    """Check inferred types for ``F.expand_dims`` with positive/negative axes.

    Expected strings use the ``Type(dtype, shape)`` rendering, e.g.
    ``ndarray(float64, (2, 3, 4))``. That is the form the other
    type-inference tests in this file expect from the same ``str()``
    implementation. The previous ``ndarray(dtype=..., shape=...)`` spelling
    could not match it.
    """
    class Test():
        def forward(self):
            F.expand_dims(np.zeros((2, 3, 4)), 1)
            F.expand_dims(np.zeros((2, 3, 4)), -2)

    types_by_id = generate_id2type_from_forward(Test(), ())

    expectations = [
        (1, "class Test -> NoneType"),  # FunctionDef forward (line 1)
        # F.expand_dims(np.zeros((2, 3, 4)), 1)
        (5, "NoneType"),  # Expr
        (6, "Variable(float64, (2, 1, 3, 4))"),  # Call F.expand_dims(np.zeros((2, 3, 4)), 1) (line 2)
        (7, "ndarray(float64, (2, 3, 4)) -> int -> Variable(float64, (2, 1, 3, 4))"),  # Attribute F.expand_dims (line 2)
        (11, "ndarray(float64, (2, 3, 4))"),  # Call np.zeros((2, 3, 4)) (line 2)
        (12, "(int, int, int) -> ndarray(float64, (2, 3, 4))"),  # Attribute np.zeros (line 2)
        (16, "(int, int, int)"),  # Tuple (2, 3, 4) (line 2)
        (17, "int"),  # Num 2 (line 2)
        (18, "int"),  # Num 3 (line 2)
        (19, "int"),  # Num 4 (line 2)
        (21, "int"),  # Num 1 (line 2)
        # F.expand_dims(np.zeros((2, 3, 4)), -2)
        (22, "NoneType"),  # Expr
        (23, "Variable(float64, (2, 3, 1, 4))"),  # Call F.expand_dims(np.zeros((2, 3, 4)), -2) (line 3)
        (24, "ndarray(float64, (2, 3, 4)) -> int -> Variable(float64, (2, 3, 1, 4))"),  # Attribute F.expand_dims (line 3)
        (28, "ndarray(float64, (2, 3, 4))"),  # Call np.zeros((2, 3, 4)) (line 3)
        (29, "(int, int, int) -> ndarray(float64, (2, 3, 4))"),  # Attribute np.zeros (line 3)
        (33, "(int, int, int)"),  # Tuple (2, 3, 4) (line 3)
        (34, "int"),  # Num 2 (line 3)
        (35, "int"),  # Num 3 (line 3)
        (36, "int"),  # Num 4 (line 3)
        (38, "int"),  # UnaryOp -2 (line 3)
        (40, "int"),  # Num 2 (line 3)
    ]
    for node_id, expected_type in expectations:
        self.assertEqual(str(types_by_id[node_id]), expected_type)
def test_squeeze(self):
    """Check inferred types for ``F.squeeze`` with no axis, an int axis, and a tuple axis.

    Each entry pairs a node id from ``generate_id2type_from_forward`` with
    the expected ``str()`` of its inferred type; the trailing comment names
    the AST node and its line within ``forward``.
    """
    class Test():
        def forward(self):
            F.squeeze(np.zeros((2, 1, 1, 3)))
            F.squeeze(np.zeros((2, 1, 1, 3)), axis=2)
            F.squeeze(np.zeros((2, 1, 1, 3)), axis=(1, 2))

    types_by_id = generate_id2type_from_forward(Test(), ())

    expectations = [
        (1, "class Test -> NoneType"),  # FunctionDef forward (line 1)
        # F.squeeze(np.zeros((2, 1, 1, 3)))
        (5, "NoneType"),  # Expr
        (6, "Variable(float64, (2, 3))"),  # Call F.squeeze(np.zeros((2, 1, 1, 3))) (line 2)
        (11, "ndarray(float64, (2, 1, 1, 3))"),  # Call np.zeros((2, 1, 1, 3)) (line 2)
        (16, "(int, int, int, int)"),  # Tuple (2, 1, 1, 3) (line 2)
        (17, "int"),  # Num 2 (line 2)
        (18, "int"),  # Num 1 (line 2)
        (19, "int"),  # Num 1 (line 2)
        (20, "int"),  # Num 3 (line 2)
        # F.squeeze(np.zeros((2, 1, 1, 3)), axis=2)
        (22, "NoneType"),  # Expr
        (23, "Variable(float64, (2, 1, 3))"),  # Call F.squeeze(..., axis=2) (line 3)
        (28, "ndarray(float64, (2, 1, 1, 3))"),  # Call np.zeros((2, 1, 1, 3)) (line 3)
        (33, "(int, int, int, int)"),  # Tuple (2, 1, 1, 3) (line 3)
        (34, "int"),  # Num 2 (line 3)
        (35, "int"),  # Num 1 (line 3)
        (36, "int"),  # Num 1 (line 3)
        (37, "int"),  # Num 3 (line 3)
        (40, "int"),  # Num 2 (line 3)
        # F.squeeze(np.zeros((2, 1, 1, 3)), axis=(1, 2))
        (41, "NoneType"),  # Expr
        (42, "Variable(float64, (2, 3))"),  # Call F.squeeze(..., axis=(1, 2)) (line 4)
        (47, "ndarray(float64, (2, 1, 1, 3))"),  # Call np.zeros((2, 1, 1, 3)) (line 4)
        (52, "(int, int, int, int)"),  # Tuple (2, 1, 1, 3) (line 4)
        (53, "int"),  # Num 2 (line 4)
        (54, "int"),  # Num 1 (line 4)
        (55, "int"),  # Num 1 (line 4)
        (56, "int"),  # Num 3 (line 4)
        (59, "(int, int)"),  # Tuple (1, 2) (line 4)
        (60, "int"),  # Num 1 (line 4)
        (61, "int"),  # Num 2 (line 4)
    ]
    for node_id, expected_type in expectations:
        self.assertEqual(str(types_by_id[node_id]), expected_type)
def test_MNIST(self):
    """Check inferred tensor types through the forward pass of the MNIST model.

    The model and its forward arguments come from ``gen_MNIST_model()``. The
    type-inference state is reset first so ids and types do not leak in
    from earlier tests.
    """
    type_inference_tools.reset_state()
    model, forward_args = gen_MNIST_model()
    types_by_id = generate_id2type_from_forward(model, forward_args)

    # Expected type per node id of ``Net.forward``; the comment names the
    # node and its line within ``forward``.
    expectations = [
        (8, "torch.Tensor(float32, (64, 32, 26, 26))"),  # Name x (line 2)
        (10, "torch.Tensor(float32, (64, 32, 26, 26))"),  # Call self.conv1(x) (line 2)
        (12, "class Net"),  # Name self (line 2)
        (15, "torch.Tensor(float32, (64, 1, 28, 28))"),  # Name x (line 2)
        (18, "torch.Tensor(float32, (64, 32, 26, 26))"),  # Name x (line 3)
        (20, "torch.Tensor(float32, (64, 32, 26, 26))"),  # Call F.relu(x) (line 3)
        (25, "torch.Tensor(float32, (64, 32, 26, 26))"),  # Name x (line 3)
        (28, "torch.Tensor(float32, (64, 64, 24, 24))"),  # Name x (line 4)
        (30, "torch.Tensor(float32, (64, 64, 24, 24))"),  # Call self.conv2(x) (line 4)
        (32, "class Net"),  # Name self (line 4)
        (35, "torch.Tensor(float32, (64, 32, 26, 26))"),  # Name x (line 4)
        (38, "torch.Tensor(float32, (64, 64, 12, 12))"),  # Name x (line 5)
        (40, "torch.Tensor(float32, (64, 64, 12, 12))"),  # Call F.max_pool2d(x, 2) (line 5)
        (45, "torch.Tensor(float32, (64, 64, 24, 24))"),  # Name x (line 5)
        (47, "int"),  # Constant 2 (line 5)
        (49, "torch.Tensor(float32, (64, 64, 12, 12))"),  # Name x (line 6)
        (51, "torch.Tensor(float32, (64, 64, 12, 12))"),  # Call self.dropout1(x) (line 6)
        (53, "class Net"),  # Name self (line 6)
        (56, "torch.Tensor(float32, (64, 64, 12, 12))"),  # Name x (line 6)
        (59, "torch.Tensor(float32, (64, 9216))"),  # Name x (line 7)
        (61, "torch.Tensor(float32, (64, 9216))"),  # Call torch.flatten(x, start_dim=1) (line 7)
        (66, "torch.Tensor(float32, (64, 64, 12, 12))"),  # Name x (line 7)
        (69, "int"),  # Constant 1 (line 7)
        (71, "torch.Tensor(float32, (64, 128))"),  # Name x (line 8)
        (73, "torch.Tensor(float32, (64, 128))"),  # Call self.fc1(x) (line 8)
        (75, "class Net"),  # Name self (line 8)
        (78, "torch.Tensor(float32, (64, 9216))"),  # Name x (line 8)
        (81, "torch.Tensor(float32, (64, 128))"),  # Name x (line 9)
        (83, "torch.Tensor(float32, (64, 128))"),  # Call F.relu(x) (line 9)
        (88, "torch.Tensor(float32, (64, 128))"),  # Name x (line 9)
        (91, "torch.Tensor(float32, (64, 128))"),  # Name x (line 10)
        (93, "torch.Tensor(float32, (64, 128))"),  # Call self.dropout2(x) (line 10)
        (95, "class Net"),  # Name self (line 10)
        (98, "torch.Tensor(float32, (64, 128))"),  # Name x (line 10)
        (101, "torch.Tensor(float32, (64, 10))"),  # Name x (line 11)
        (103, "torch.Tensor(float32, (64, 10))"),  # Call self.fc2(x) (line 11)
        (105, "class Net"),  # Name self (line 11)
        (108, "torch.Tensor(float32, (64, 128))"),  # Name x (line 11)
        (111, "torch.Tensor(float32, (64, 10))"),  # Name output (line 12)
        (113, "torch.Tensor(float32, (64, 10))"),  # Call F.log_softmax(x, dim=1) (line 12)
        (118, "torch.Tensor(float32, (64, 10))"),  # Name x (line 12)
        (121, "int"),  # Constant 1 (line 12)
        (123, "torch.Tensor(float32, (64, 10))"),  # Name output (line 13)
    ]
    for node_id, expected_type in expectations:
        self.assertEqual(str(types_by_id[node_id]), expected_type)
def test_num_bool(self):
    """Check inferred types for a nested ``and``/``or`` over bool arguments."""
    class Test():
        def forward(self, x, y):
            return x and y or True

    types_by_id = generate_id2type_from_forward(Test(), (True, False))

    def expect(node_id, expected_type):
        # Compare the rendered type of one node against the expected string.
        self.assertEqual(str(types_by_id[node_id]), expected_type)

    expect(1, "class Test -> bool -> bool -> bool")  # FunctionDef (line 1)
    expect(9, "bool")  # Return (line 2)
    expect(10, "bool")  # BoolOp (line 2)
    expect(11, "bool -> bool -> bool")  # Or
    expect(12, "bool")  # BoolOp (line 2)
    expect(13, "bool -> bool -> bool")  # And
    expect(14, "bool")  # Name (line 2)
    expect(16, "bool")  # Name (line 2)
    expect(18, "bool")  # NameConstant (line 2)