def test_load_pth(self):
    """Check that Jittor's Resnet18 reproduces torchvision's pretrained resnet18.

    Parameters are transferred twice:
      1. directly, via ``jt_model.load_parameters(torch_model.state_dict())``;
      2. through a ``.pth`` file written with ``torch.save`` and read back
         with ``jt_model.load``.
    After each transfer, the Jittor output for a random input is compared
    with the torch output for the same input. The maximum absolute
    difference must be below 1e-3.
    """
    # TODO: load torch model params

    def max_abs_diff(jt_var, torch_tensor):
        # Largest elementwise |jittor - torch| between the two outputs.
        return np.max(np.abs(jt_var.fetch_sync() - torch_tensor.detach().numpy()))

    # define input img (same data fed to both frameworks)
    img = np.random.random((1, 3, 224, 224)).astype("float32")
    jt_img = jt.array(img)
    torch_img = torch.Tensor(img)

    # define pytorch and jittor pretrained model
    torch_model = tv.models.resnet18(True)
    jt_model = resnet.Resnet18()
    jt_model.load_parameters(torch_model.state_dict())
    # TODO: neither model is explicitly switched with train()/eval() here.

    # output
    jt_out = jt_model(jt_img)
    torch_out = torch_model(torch_img)
    diff = max_abs_diff(jt_out, torch_out)
    print(diff)
    assert diff < 1e-3

    # Round-trip through a .pth file. state_dict must be *called*: passing the
    # bound method itself would serialize a method object, not the parameter
    # dict this round-trip is meant to exercise.
    pth_name = os.path.join(jt.flags.cache_path, "x.pth")
    torch.save(torch_model.state_dict(), pth_name)
    jt_model.load(pth_name)

    # output: compared against the torch output computed above
    jt_out = jt_model(jt_img)
    diff = max_abs_diff(jt_out, torch_out)
    print(diff)
    assert diff < 1e-3
def test_resnet_infer_with_feature(self):
    """Trace pretrained Resnet18 inference on a downloaded image.

    The image is downloaded into ``jt.flags.cache_path`` and preprocessed.
    It is then run through ``resnet.Resnet18(pretrained=True)`` with
    ``trace_py_var=2`` and ``trace_var_data=1`` enabled. The dumped trace is
    pickled to ``resnet_with_feature.pkl``. The test asserts that every id
    listed in a ``fused_ops`` entry of ``execute_op_info`` is a key of
    ``node_data``.
    """
    cat_url = "https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=3782485413,1118109468&fm=26&gp=0.jpg"
    import jittor_utils
    cat_path = f"{jt.flags.cache_path}/cat.jpg"
    print("download")
    jittor_utils.download(cat_url, cat_path)

    # `img` is deliberately rebound at each step so that only the final
    # (lazy) Jittor var stays alive, as in the original.
    with open(cat_path, 'rb') as image_file:
        img = Image.open(image_file).convert('RGB')
        img = jt.array(np.array(img))
        print(img.shape, img.dtype)
        # Shift/scale pixels, then move the last axis to the front.
        img = ((img.float() - 128) / 255).transpose(2, 0, 1)

    with jt.flag_scope(trace_py_var=2, trace_var_data=1):
        img = img[None, ...]  # prepend a batch axis of size 1
        resnet18 = resnet.Resnet18(pretrained=True)
        x = jt.float32(img)
        y = resnet18(x)
        y.sync()
        data = jt.dump_trace_data()
        jt.clear_trace_data()

    with open(f"{jt.flags.cache_path}/resnet_with_feature.pkl", "wb") as trace_file:
        pickle.dump(data, trace_file)

    recorded_nodes = data["node_data"]
    for exec_info in data["execute_op_info"].values():
        for fused_id in exec_info['fused_ops']:
            assert fused_id in recorded_nodes, (fused_id, "not found")
def __init__(self, latent_dim, input_shape):
    """Build the encoder: truncated Resnet18 trunk, mean pooling, two linear heads.

    Args:
        latent_dim: output width of both ``fc_mu`` and ``fc_logvar``.
        input_shape: accepted by the signature but not used in this method.
    """
    super().__init__()
    backbone = resnet.Resnet18()
    # Drop the last three children of the backbone. The heads below take
    # 256 inputs, so presumably the truncated trunk yields 256 channels.
    trunk_layers = list(backbone.children())[:-3]
    self.feature_extractor = nn.Sequential(*trunk_layers)
    self.pooling = nn.Pool(kernel_size=8, stride=8, padding=0, op='mean')
    self.fc_mu = nn.Linear(256, latent_dim)
    self.fc_logvar = nn.Linear(256, latent_dim)
    # Apply the custom initializer to every submodule, including self.
    for submodule in self.modules():
        weights_init_normal(submodule)
def test_resnet(self):
    """Run Resnet18 inference on a random batch with ``trace_py_var=2``.

    This is a smoke test. It only requires that running, dumping the trace
    and clearing it succeed. The dumped trace is not inspected.
    """
    with jt.flag_scope(trace_py_var=2):
        resnet18 = resnet.Resnet18()
        batch = np.random.rand(2, 3, 224, 224)
        x = jt.float32(batch)
        y = resnet18(x)
        y.sync()
        jt.dump_trace_data()  # result intentionally unused
        jt.clear_trace_data()
def test_resnet_train_profile(self):
    """Profile one SGD step of Resnet18 on a random batch.

    Runs under ``jt.profile_scope(trace_py_var=1)``. The model output
    squared is passed to ``opt.step``, then everything is synced.
    """
    with jt.profile_scope(trace_py_var=1):
        resnet18 = resnet.Resnet18()
        opt = jt.optim.SGD(resnet18.parameters(), 0.1)
        input_shape = (2, 3, 224, 224)
        x = jt.float32(np.random.rand(*input_shape))
        y = resnet18(x)
        opt.step(y ** 2)
        jt.sync_all()
def test_resnet_train(self):
    """Run one traced SGD step of Resnet18 on a random batch.

    Runs under ``trace_py_var=2``. The model output squared is passed to
    ``opt.step``, everything is synced, and the trace is then dumped and
    cleared. The dump is not inspected.
    """
    with jt.flag_scope(trace_py_var=2):
        resnet18 = resnet.Resnet18()
        opt = jt.optim.SGD(resnet18.parameters(), 0.1)
        input_shape = (2, 3, 224, 224)
        x = jt.float32(np.random.rand(*input_shape))
        y = resnet18(x)
        opt.step(y ** 2)
        jt.sync_all()
        jt.dump_trace_data()  # result intentionally unused
        jt.clear_trace_data()
def test_resnet_infer(self):
    """Trace Resnet18 inference on a random batch and check trace consistency.

    The trace dumped under ``trace_py_var=2`` is pickled to ``resnet.pkl``
    in ``jt.flags.cache_path``. The test asserts that every id listed in a
    ``fused_ops`` entry of ``execute_op_info`` is a key of ``node_data``.
    """
    with jt.flag_scope(trace_py_var=2):
        resnet18 = resnet.Resnet18()
        x = jt.float32(np.random.rand(2, 3, 224, 224))
        y = resnet18(x)
        y.sync()
        data = jt.dump_trace_data()
        jt.clear_trace_data()

    with open(f"{jt.flags.cache_path}/resnet.pkl", "wb") as trace_file:
        pickle.dump(data, trace_file)

    recorded_nodes = data["node_data"]
    for exec_info in data["execute_op_info"].values():
        for fused_id in exec_info['fused_ops']:
            assert fused_id in recorded_nodes, (fused_id, "not found")
def test_resnet_trainx(self):
    """Trace one SGD step of Resnet18 and check the dumped trace.

    The trace from ``trace_py_var=2`` is pickled to ``resnet_train.pkl``.
    The test asserts that every id in each ``fused_ops`` list of
    ``execute_op_info`` is a key of ``node_data``. It then prints every
    ``node_data`` entry whose ``attrs`` lack a ``'name'`` key.
    """
    with jt.flag_scope(trace_py_var=2):
        resnet18 = resnet.Resnet18()
        opt = jt.optim.SGD(resnet18.parameters(), 0.1)
        x = jt.float32(np.random.rand(2, 3, 224, 224))
        y = resnet18(x)
        opt.step(y ** 2)
        jt.sync_all()
        data = jt.dump_trace_data()
        jt.clear_trace_data()

    with open(f"{jt.flags.cache_path}/resnet_train.pkl", "wb") as trace_file:
        pickle.dump(data, trace_file)

    recorded_nodes = data["node_data"]
    for exec_info in data["execute_op_info"].values():
        for fused_id in exec_info['fused_ops']:
            assert fused_id in recorded_nodes, (fused_id, "not found")

    # Report (but do not fail on) nodes recorded without a name attribute.
    for node in recorded_nodes.values():
        if 'name' not in node["attrs"]:
            print(node)
def __init__(self):
    """Hold a Resnet18 as ``self.model`` and a ``Linear(1000, 10)`` as ``self.layer``.

    Presumably the linear layer maps the backbone's 1000 outputs to 10;
    confirm against this class's forward/execute method.
    """
    backbone = resnet.Resnet18()
    head = nn.Linear(1000, 10)
    self.model, self.layer = backbone, head