Beispiel #1
0
    def test_load_pth(self):
        """Check that a Jittor ResNet-18 reproduces torchvision's ResNet-18.

        The torch weights are transferred twice: once directly from the
        in-memory ``state_dict`` via ``load_parameters``, and once through a
        ``.pth`` file written with ``torch.save`` and read back with
        ``jt_model.load``. After each transfer the maximum absolute difference
        between the two models' outputs on the same random input is printed
        and must be below 1e-3.
        """
        # Random input image shared by both frameworks.
        img = np.random.random((1, 3, 224, 224)).astype("float32")
        jt_img = jt.array(img)
        torch_img = torch.Tensor(img)

        # Pretrained torchvision model and a freshly constructed Jittor model
        # that receives the torch parameters.
        torch_model = tv.models.resnet18(pretrained=True)
        jt_model = resnet.Resnet18()
        jt_model.load_parameters(torch_model.state_dict())
        # TODO: compare in both model.train() and model.eval() modes.

        # 1) Parameters loaded directly from the in-memory state dict.
        jt_out = jt_model(jt_img).fetch_sync()
        torch_out = torch_model(torch_img).detach().numpy()
        diff = np.max(np.abs(jt_out - torch_out))
        print(diff)
        assert diff < 1e-3

        # 2) Parameters loaded from a .pth file. ``state_dict`` must be
        # *called*: passing the bound method would serialize a method object
        # instead of the parameter dict that ``jt_model.load`` expects.
        pth_name = os.path.join(jt.flags.cache_path, "x.pth")
        torch.save(torch_model.state_dict(), pth_name)
        jt_model.load(pth_name)

        # The torch reference output computed above is reused here.
        jt_out = jt_model(jt_img).fetch_sync()
        diff = np.max(np.abs(jt_out - torch_out))
        print(diff)
        assert diff < 1e-3
Beispiel #2
0
    def test_resnet_infer_with_feature(self):
        """Trace pretrained ResNet-18 inference on a real image and check the trace.

        Downloads a sample cat image into the cache directory and runs
        ``Resnet18(pretrained=True)`` on it with ``trace_py_var=2`` and
        ``trace_var_data=1``. The dumped trace is pickled to
        ``resnet_with_feature.pkl`` in the cache directory. The test asserts
        that every fused op listed in ``execute_op_info`` has an entry in
        ``node_data``.
        """
        cat_url = "https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=3782485413,1118109468&fm=26&gp=0.jpg"
        import jittor_utils
        cat_path = f"{jt.flags.cache_path}/cat.jpg"
        print("download")
        jittor_utils.download(cat_url, cat_path)
        with open(cat_path, 'rb') as img_file:
            img = Image.open(img_file).convert('RGB')
            img = jt.array(np.array(img))
            print(img.shape, img.dtype)
            # (H, W, C) -> (C, H, W) after the (x - 128) / 255 scaling.
            img = ((img.float() - 128) / 255).transpose(2, 0, 1)

        with jt.flag_scope(trace_py_var=2, trace_var_data=1):
            # Add a leading batch dimension.
            img = img[None, ...]

            resnet18 = resnet.Resnet18(pretrained=True)
            x = jt.float32(img)
            y = resnet18(x)
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            with open(f"{jt.flags.cache_path}/resnet_with_feature.pkl",
                      "wb") as pkl_file:
                pickle.dump(data, pkl_file)
            # Every fused op referenced by an executed op must be a known node.
            for op_info in data["execute_op_info"].values():
                for fused_op in op_info['fused_ops']:
                    assert fused_op in data["node_data"], (fused_op, "not found")
Beispiel #3
0
 def __init__(self, latent_dim, input_shape):
     """Build the encoder: a truncated ResNet-18, mean pooling, and two heads.

     ``latent_dim`` sets the output size of both the ``fc_mu`` and the
     ``fc_logvar`` linear heads. ``input_shape`` is accepted but not used
     here. Every sub-module is finally passed through ``weights_init_normal``.
     """
     super(Encoder, self).__init__()
     # Keep all of ResNet-18's children except the last three.
     backbone_children = list(resnet.Resnet18().children())
     self.feature_extractor = nn.Sequential(*backbone_children[:-3])
     self.pooling = nn.Pool(kernel_size=8, stride=8, padding=0, op='mean')
     # NOTE(review): 256 presumably matches the channel count produced by
     # the truncated backbone — confirm.
     self.fc_mu = nn.Linear(256, latent_dim)
     self.fc_logvar = nn.Linear(256, latent_dim)
     for module in self.modules():
         weights_init_normal(module)
Beispiel #4
0
    def test_resnet(self):
        # Smoke test: a ResNet-18 forward pass on a random (2, 3, 224, 224)
        # batch with Python-variable tracing (level 2) must run, and the
        # trace must be dumpable afterwards.
        with jt.flag_scope(trace_py_var=2):
            model = resnet.Resnet18()
            batch = jt.float32(np.random.rand(2, 3, 224, 224))
            output = model(batch)
            output.sync()

            # Dump the collected trace, then clear it.
            trace = jt.dump_trace_data()
            jt.clear_trace_data()
Beispiel #5
0
    def test_resnet_train_profile(self):
        # Run one SGD step on ResNet-18 (stepping on y ** 2) inside a
        # profiling scope with Python-variable tracing at level 1.
        with jt.profile_scope(trace_py_var=1):
            model = resnet.Resnet18()
            sgd = jt.optim.SGD(model.parameters(), 0.1)
            inputs = jt.float32(np.random.rand(2, 3, 224, 224))
            outputs = model(inputs)

            loss = outputs ** 2
            sgd.step(loss)
            jt.sync_all()
Beispiel #6
0
    def test_resnet_train(self):
        # Run one SGD step on ResNet-18 (stepping on y ** 2) with
        # Python-variable tracing at level 2, then dump and clear the trace.
        with jt.flag_scope(trace_py_var=2):
            model = resnet.Resnet18()
            sgd = jt.optim.SGD(model.parameters(), 0.1)
            inputs = jt.float32(np.random.rand(2, 3, 224, 224))
            outputs = model(inputs)

            loss = outputs ** 2
            sgd.step(loss)
            jt.sync_all()

            trace = jt.dump_trace_data()
            jt.clear_trace_data()
Beispiel #7
0
    def test_resnet_infer(self):
        """Trace a ResNet-18 forward pass and check the dumped trace.

        Runs inference on a random (2, 3, 224, 224) batch with
        ``trace_py_var=2``. The trace is pickled to ``resnet.pkl`` in the
        cache directory. The test asserts that every fused op listed in
        ``execute_op_info`` has an entry in ``node_data``.
        """
        with jt.flag_scope(trace_py_var=2):
            resnet18 = resnet.Resnet18()
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            with open(f"{jt.flags.cache_path}/resnet.pkl", "wb") as f:
                pickle.dump(data, f)
            # Only the per-op info is needed, not its key.
            for op_info in data["execute_op_info"].values():
                for fused_op in op_info['fused_ops']:
                    assert fused_op in data["node_data"], (fused_op, "not found")
Beispiel #8
0
    def test_resnet_trainx(self):
        """Trace one ResNet-18 SGD step and check the dumped trace.

        Runs one SGD step (stepping on ``y ** 2``) with ``trace_py_var=2``.
        The trace is pickled to ``resnet_train.pkl`` in the cache directory.
        The test asserts that every fused op listed in ``execute_op_info``
        has an entry in ``node_data``. It also prints every node whose
        ``attrs`` lacks a ``'name'`` entry; this is diagnostic output only
        and does not fail the test.
        """
        with jt.flag_scope(trace_py_var=2):
            resnet18 = resnet.Resnet18()
            opt = jt.optim.SGD(resnet18.parameters(), 0.1)
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)

            opt.step(y**2)
            jt.sync_all()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            with open(f"{jt.flags.cache_path}/resnet_train.pkl", "wb") as f:
                pickle.dump(data, f)
            # Every fused op referenced by an executed op must be a known node.
            for op_info in data["execute_op_info"].values():
                for fused_op in op_info['fused_ops']:
                    assert fused_op in data["node_data"], (fused_op, "not found")
            # Report (but tolerate) nodes recorded without a name.
            for node in data["node_data"].values():
                if 'name' not in node["attrs"]:
                    print(node)
Beispiel #9
0
 def __init__(self):
     """Create a ResNet-18 (``self.model``) and a ``Linear(1000, 10)`` layer (``self.layer``).

     NOTE(review): presumably ``self.layer`` maps the backbone's 1000-dim
     output to 10 classes in the forward/execute method, which is not
     visible here — confirm.
     NOTE(review): no ``super().__init__()`` call is made; this assumes the
     base class does not require one — confirm.
     """
     self.model = resnet.Resnet18()
     self.layer = nn.Linear(1000, 10)