def test_simple_model_train(self):
    """Trace one SGD training step of ``Model`` and validate the dumped trace.

    Runs a forward pass on a random ``(10, 1)`` batch followed by
    ``opt.step(y**2)`` under ``trace_py_var=2``, dumps and clears the trace,
    then checks that:

    * every op id listed in ``execute_op_info[...]["fused_ops"]`` has an
      entry in ``node_data``;
    * no entry in ``node_data`` has ``attrs["name"] == "unname"``.

    The node count is printed and the trace is pickled to
    ``<cache_path>/simple_model_train.pkl``.

    NOTE(review): another method named ``test_simple_model_train`` appears
    later in this file. If both live in the same class, only the later
    definition is bound and this one never runs -- consider renaming it.
    """
    with jt.flag_scope(trace_py_var=2):
        model = Model(input_size=1)
        opt = jt.optim.SGD(model.parameters(), 0.1)

        batch_size = 10
        x = jt.float32(np.random.rand(batch_size, 1))
        y = model(x)
        opt.step(y**2)
        jt.sync_all()

        data = jt.dump_trace_data()
        jt.clear_trace_data()

    # Debugging aid: print_stack_tree(data)
    node_data = data["node_data"]
    # Each fused op referenced by an execution record must be a known node.
    for info in data["execute_op_info"].values():
        for op_id in info["fused_ops"]:
            assert op_id in node_data, (op_id, "not found")
    # Fail (with the offending id) if any node is labelled "unname".
    for node_id, node in node_data.items():
        assert node["attrs"]["name"] != "unname", (node_id, "unnamed node")
    print(len(node_data))
    with open(f"{jt.flags.cache_path}/simple_model_train.pkl", "wb") as f:
        pickle.dump(data, f)
def test_resnet_infer_with_feature(self):
    """Trace pretrained ``resnet.Resnet18`` inference on a real image.

    Downloads a sample JPEG into ``cache_path``, converts it to RGB, and
    normalises it as ``(pixel - 128) / 255`` with the channel axis moved
    first. The pretrained network then runs under
    ``trace_py_var=2, trace_var_data=1``.

    The dumped trace is pickled to
    ``<cache_path>/resnet_with_feature.pkl``. Every op id listed in
    ``execute_op_info[...]["fused_ops"]`` must have an entry in
    ``node_data``.

    NOTE(review): needs network access to the hard-coded image URL (and,
    presumably, to fetch the pretrained weights).
    """
    cat_url = "https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=3782485413,1118109468&fm=26&gp=0.jpg"
    import jittor_utils
    cat_path = f"{jt.flags.cache_path}/cat.jpg"
    print("download")
    jittor_utils.download(cat_url, cat_path)
    with open(cat_path, 'rb') as f:
        img = Image.open(f).convert('RGB')
        img = jt.array(np.array(img))
        print(img.shape, img.dtype)
        # Normalise and reorder axes (H, W, C) -> (C, H, W).
        img = ((img.float() - 128) / 255).transpose(2, 0, 1)

    with jt.flag_scope(trace_py_var=2, trace_var_data=1):
        img = img[None, ...]  # prepend a batch axis of size 1
        resnet18 = resnet.Resnet18(pretrained=True)
        x = jt.float32(img)
        y = resnet18(x)
        y.sync()
        data = jt.dump_trace_data()
        jt.clear_trace_data()

    with open(f"{jt.flags.cache_path}/resnet_with_feature.pkl", "wb") as f:
        pickle.dump(data, f)
    node_data = data["node_data"]
    # Each fused op referenced by an execution record must be a known node.
    for info in data["execute_op_info"].values():
        for op_id in info["fused_ops"]:
            assert op_id in node_data, (op_id, "not found")
def test_resnet(self):
    """Smoke-test tracing of a default-constructed ``resnet.Resnet18``.

    A random batch of shape ``(2, 3, 224, 224)`` is pushed through the
    network with ``trace_py_var=2``. The trace is dumped and then cleared.
    Nothing about its contents is asserted: the test passes as long as no
    step raises.
    """
    input_shape = (2, 3, 224, 224)
    with jt.flag_scope(trace_py_var=2):
        net = resnet.Resnet18()
        batch = jt.float32(np.random.rand(*input_shape))
        out = net(batch)
        out.sync()
        _trace = jt.dump_trace_data()
        jt.clear_trace_data()
def test_simple_model(self):
    """Smoke-test tracing of a single ``Model`` forward pass.

    Feeds a random ``(10, 1)`` batch through ``Model(input_size=1)`` under
    ``trace_py_var=2``, then dumps and clears the trace without inspecting
    it.

    NOTE(review): a later method in this file is also named
    ``test_simple_model``. If both share a class, this definition is
    shadowed and never runs.
    """
    n_samples = 10
    with jt.flag_scope(trace_py_var=2):
        net = Model(input_size=1)
        inputs = jt.float32(np.random.rand(n_samples, 1))
        outputs = net(inputs)
        outputs.sync()
        _trace = jt.dump_trace_data()
        jt.clear_trace_data()
def test_resnet_train(self):
    """Smoke-test tracing of one SGD step on ``resnet.Resnet18``.

    Uses a random ``(2, 3, 224, 224)`` batch and the squared network output
    as the loss (learning rate ``0.1``). Traced with ``trace_py_var=2``; the
    trace is dumped and cleared but not checked.
    """
    input_shape = (2, 3, 224, 224)
    with jt.flag_scope(trace_py_var=2):
        net = resnet.Resnet18()
        sgd = jt.optim.SGD(net.parameters(), 0.1)
        batch = jt.float32(np.random.rand(*input_shape))
        out = net(batch)
        sgd.step(out**2)
        jt.sync_all()
        _trace = jt.dump_trace_data()
        jt.clear_trace_data()
def test_simple_model(self):
    """Trace a ``Model`` forward pass and persist the raw trace.

    Feeds a random ``(10, 1)`` batch through ``Model(input_size=1)`` under
    ``trace_py_var=2``. The dumped trace is pickled to
    ``<cache_path>/simple_model.pkl``; its contents are not asserted.

    NOTE(review): an earlier method in this file is also named
    ``test_simple_model``. If both share a class, this definition replaces
    it.
    """
    n_samples = 10
    with jt.flag_scope(trace_py_var=2):
        net = Model(input_size=1)
        inputs = jt.float32(np.random.rand(n_samples, 1))
        outputs = net(inputs)
        outputs.sync()
        trace_data = jt.dump_trace_data()
        jt.clear_trace_data()

    out_path = f"{jt.flags.cache_path}/simple_model.pkl"
    with open(out_path, "wb") as fh:
        pickle.dump(trace_data, fh)
def test_simple_model_train(self):
    """Smoke-test tracing of one SGD step on ``Model``.

    Runs ``Model(input_size=1)`` on a random ``(10, 1)`` batch and steps SGD
    (learning rate ``0.1``) on the squared output, all under
    ``trace_py_var=2``. The trace is dumped and cleared without inspection.

    NOTE(review): an earlier method in this file is also named
    ``test_simple_model_train``. If both share a class, this definition
    replaces it.
    """
    n_samples = 10
    with jt.flag_scope(trace_py_var=2):
        net = Model(input_size=1)
        sgd = jt.optim.SGD(net.parameters(), 0.1)
        inputs = jt.float32(np.random.rand(n_samples, 1))
        preds = net(inputs)
        sgd.step(preds**2)
        jt.sync_all()
        _trace = jt.dump_trace_data()
        jt.clear_trace_data()
def test_resnet_infer(self):
    """Trace ``resnet.Resnet18`` inference and validate fused-op references.

    Runs a random ``(2, 3, 224, 224)`` batch through a default-constructed
    network under ``trace_py_var=2``. The dumped trace is pickled to
    ``<cache_path>/resnet.pkl``. Every op id listed in
    ``execute_op_info[...]["fused_ops"]`` must have an entry in
    ``node_data``.
    """
    with jt.flag_scope(trace_py_var=2):
        resnet18 = resnet.Resnet18()
        x = jt.float32(np.random.rand(2, 3, 224, 224))
        y = resnet18(x)
        y.sync()
        data = jt.dump_trace_data()
        jt.clear_trace_data()

    with open(f"{jt.flags.cache_path}/resnet.pkl", "wb") as f:
        pickle.dump(data, f)
    node_data = data["node_data"]
    # Each fused op referenced by an execution record must be a known node.
    for info in data["execute_op_info"].values():
        for op_id in info["fused_ops"]:
            assert op_id in node_data, (op_id, "not found")
def test_resnet_trainx(self):
    """Trace one SGD step on ``resnet.Resnet18`` and validate the trace.

    Uses a random ``(2, 3, 224, 224)`` batch with the squared output as the
    loss, traced under ``trace_py_var=2``. The dumped trace is pickled to
    ``<cache_path>/resnet_train.pkl``. Every op id listed in
    ``execute_op_info[...]["fused_ops"]`` must have an entry in
    ``node_data``. Nodes whose ``attrs`` lack a ``"name"`` key are printed
    for inspection but do not fail the test.
    """
    with jt.flag_scope(trace_py_var=2):
        resnet18 = resnet.Resnet18()
        opt = jt.optim.SGD(resnet18.parameters(), 0.1)
        x = jt.float32(np.random.rand(2, 3, 224, 224))
        y = resnet18(x)
        opt.step(y**2)
        jt.sync_all()
        data = jt.dump_trace_data()
        jt.clear_trace_data()

    with open(f"{jt.flags.cache_path}/resnet_train.pkl", "wb") as f:
        pickle.dump(data, f)
    node_data = data["node_data"]
    # Each fused op referenced by an execution record must be a known node.
    for info in data["execute_op_info"].values():
        for op_id in info["fused_ops"]:
            assert op_id in node_data, (op_id, "not found")
    # Diagnostic only: report (but tolerate) nodes without a name attribute.
    for node in node_data.values():
        if "name" not in node["attrs"]:
            print(node)