Ejemplo n.º 1
0
    def test_simple_model_train(self):
        """Trace one SGD step on ``Model`` and validate the dumped trace.

        A single ``opt.step(y**2)`` is executed on a random batch while the
        ``trace_py_var=2`` flag is set.  The trace buffer is then dumped and
        cleared, and two invariants are asserted:

        * every op id listed in an executed op's ``fused_ops`` has an entry
          in ``data["node_data"]``;
        * no traced node carries the name ``"unname"``.

        The node count is printed and the whole trace is pickled to
        ``{jt.flags.cache_path}/simple_model_train.pkl``.
        """
        with jt.flag_scope(trace_py_var=2):

            model = Model(input_size=1)
            opt = jt.optim.SGD(model.parameters(), 0.1)

            batch_size = 10
            x = jt.float32(np.random.rand(batch_size, 1))
            y = model(x)
            opt.step(y**2)
            jt.sync_all()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            node_data = data["node_data"]

            # Every fused op id must resolve to a recorded node.
            for info in data["execute_op_info"].values():
                for op_id in info['fused_ops']:
                    assert op_id in node_data, (op_id, "not found")

            # Report *which* node is unnamed instead of a bare AssertionError.
            for node_id, node in node_data.items():
                assert node["attrs"]["name"] != "unname", (node_id, "unnamed node")

            print(len(node_data))
            with open(f"{jt.flags.cache_path}/simple_model_train.pkl",
                      "wb") as f:
                pickle.dump(data, f)
Ejemplo n.º 2
0
    def test_resnet_infer_with_feature(self):
        """Trace a pretrained ResNet-18 forward pass on a downloaded photo.

        The image is fetched into ``jt.flags.cache_path`` and converted to
        RGB.  Pixel values are shifted by 128 and divided by 255, and the
        axes are permuted with ``transpose(2, 0, 1)`` (channels first).  It
        is then pushed through ``resnet.Resnet18(pretrained=True)`` with the
        ``trace_py_var=2`` and ``trace_var_data=1`` flags set.  The dumped
        trace is pickled before it is checked, so the file is still written
        if the check fails.  The check asserts that every fused op id
        appears in ``node_data``.
        """
        import jittor_utils
        cat_url = "https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=3782485413,1118109468&fm=26&gp=0.jpg"
        cat_path = f"{jt.flags.cache_path}/cat.jpg"
        print("download")
        jittor_utils.download(cat_url, cat_path)

        with open(cat_path, 'rb') as fp:
            pixels = jt.array(np.array(Image.open(fp).convert('RGB')))
            print(pixels.shape, pixels.dtype)
            # HWC uint8 pixels -> normalised float tensor with channels first
            img = ((pixels.float() - 128) / 255).transpose(2, 0, 1)

        with jt.flag_scope(trace_py_var=2, trace_var_data=1):
            img = img[None, ...]

            resnet18 = resnet.Resnet18(pretrained=True)
            x = jt.float32(img)
            y = resnet18(x)
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            dump_path = f"{jt.flags.cache_path}/resnet_with_feature.pkl"
            with open(dump_path, "wb") as out:
                pickle.dump(data, out)

            known_nodes = data["node_data"]
            for op_info in data["execute_op_info"].values():
                for fused_id in op_info['fused_ops']:
                    assert fused_id in known_nodes, (fused_id, "not found")
Ejemplo n.º 3
0
    def test_resnet(self):
        """Trace a ResNet-18 forward pass and sanity-check the dumped trace.

        A ``resnet.Resnet18()`` is run on a random ``(2, 3, 224, 224)``
        float32 batch while the ``trace_py_var=2`` flag is set.  The trace
        buffer is then dumped and cleared.  Previously the dump was never
        inspected; the test now asserts that every op id listed in an
        executed op's ``fused_ops`` has an entry in ``node_data``.
        """
        with jt.flag_scope(trace_py_var=2):

            resnet18 = resnet.Resnet18()
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()

            # Every fused op id must resolve to a recorded node.
            node_data = data["node_data"]
            for info in data["execute_op_info"].values():
                for op_id in info["fused_ops"]:
                    assert op_id in node_data, (op_id, "not found")
Ejemplo n.º 4
0
    def test_simple_model(self):
        """Trace a forward pass of ``Model`` and sanity-check the trace.

        ``Model(input_size=1)`` is run on a random ``(10, 1)`` float32 batch
        while the ``trace_py_var=2`` flag is set.  The trace buffer is then
        dumped and cleared.  The test asserts that every op id listed in an
        executed op's ``fused_ops`` has an entry in ``node_data``; the dump
        used to be discarded unchecked.
        """
        with jt.flag_scope(trace_py_var=2):

            model = Model(input_size=1)
            batch_size = 10
            x = jt.float32(np.random.rand(batch_size, 1))
            y = model(x)
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()

            # Every fused op id must resolve to a recorded node.
            node_data = data["node_data"]
            for info in data["execute_op_info"].values():
                for op_id in info["fused_ops"]:
                    assert op_id in node_data, (op_id, "not found")
Ejemplo n.º 5
0
    def test_resnet_train(self):
        """Trace one SGD step on ResNet-18 and sanity-check the trace.

        A ``resnet.Resnet18()`` is run on a random ``(2, 3, 224, 224)``
        batch.  One ``opt.step(y**2)`` is then taken with
        ``jt.optim.SGD(..., 0.1)`` while the ``trace_py_var=2`` flag is set.
        After ``jt.sync_all()`` the trace is dumped and cleared.  The test
        asserts that every op id listed in an executed op's ``fused_ops``
        has an entry in ``node_data``, which is the same check
        ``test_resnet_trainx`` performs.
        """
        with jt.flag_scope(trace_py_var=2):

            resnet18 = resnet.Resnet18()
            opt = jt.optim.SGD(resnet18.parameters(), 0.1)
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)

            opt.step(y**2)
            jt.sync_all()

            data = jt.dump_trace_data()
            jt.clear_trace_data()

            # Every fused op id must resolve to a recorded node.
            node_data = data["node_data"]
            for info in data["execute_op_info"].values():
                for op_id in info["fused_ops"]:
                    assert op_id in node_data, (op_id, "not found")
Ejemplo n.º 6
0
    def test_simple_model(self):
        """Trace a forward pass of ``Model``, pickle and check the trace.

        ``Model(input_size=1)`` is run on a random ``(10, 1)`` float32 batch
        while the ``trace_py_var=2`` flag is set.  The trace is dumped,
        cleared and pickled to ``{jt.flags.cache_path}/simple_model.pkl``.
        The file is written first so it exists even if the check fails.
        The test then asserts that every op id listed in an executed op's
        ``fused_ops`` has an entry in ``node_data``.
        """
        with jt.flag_scope(trace_py_var=2):

            model = Model(input_size=1)
            batch_size = 10
            x = jt.float32(np.random.rand(batch_size, 1))
            y = model(x)
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            with open(f"{jt.flags.cache_path}/simple_model.pkl", "wb") as f:
                pickle.dump(data, f)

            # Every fused op id must resolve to a recorded node.
            node_data = data["node_data"]
            for info in data["execute_op_info"].values():
                for op_id in info["fused_ops"]:
                    assert op_id in node_data, (op_id, "not found")
Ejemplo n.º 7
0
    def test_simple_model_train(self):
        """Trace one SGD step on ``Model`` and sanity-check the trace.

        ``Model(input_size=1)`` is run on a random ``(10, 1)`` batch.  One
        ``opt.step(y**2)`` is then taken with ``jt.optim.SGD(..., 0.1)``
        while the ``trace_py_var=2`` flag is set, and the trace is dumped
        and cleared.  The test asserts that every op id listed in an
        executed op's ``fused_ops`` has an entry in ``node_data``; the dump
        used to be discarded unchecked.
        """
        with jt.flag_scope(trace_py_var=2):

            model = Model(input_size=1)
            opt = jt.optim.SGD(model.parameters(), 0.1)

            batch_size = 10
            x = jt.float32(np.random.rand(batch_size, 1))
            y = model(x)
            opt.step(y**2)
            jt.sync_all()

            data = jt.dump_trace_data()
            jt.clear_trace_data()

            # Every fused op id must resolve to a recorded node.
            node_data = data["node_data"]
            for info in data["execute_op_info"].values():
                for op_id in info["fused_ops"]:
                    assert op_id in node_data, (op_id, "not found")
Ejemplo n.º 8
0
    def test_resnet_infer(self):
        """Trace a ResNet-18 forward pass, pickle the trace and verify it.

        The traced workload is ``resnet.Resnet18()`` applied to a random
        ``(2, 3, 224, 224)`` batch under ``trace_py_var=2``.  The dumped
        trace goes to ``{jt.flags.cache_path}/resnet.pkl`` before any
        checks run, so it is kept for inspection even on failure.
        """
        with jt.flag_scope(trace_py_var=2):

            resnet18 = resnet.Resnet18()
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)
            y.sync()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            dump_path = f"{jt.flags.cache_path}/resnet.pkl"
            with open(dump_path, "wb") as out:
                pickle.dump(data, out)

            # each fused op id must be a key of node_data
            known_nodes = data["node_data"]
            for op_info in data["execute_op_info"].values():
                for fused_id in op_info['fused_ops']:
                    assert fused_id in known_nodes, (fused_id, "not found")
Ejemplo n.º 9
0
    def test_resnet_trainx(self):
        """Trace one SGD step on ResNet-18, pickle the trace and inspect it.

        One ``opt.step(y**2)`` is taken on a random ``(2, 3, 224, 224)``
        batch under ``trace_py_var=2``.  The trace is written to
        ``{jt.flags.cache_path}/resnet_train.pkl`` first.  The test then
        asserts that every fused op id is present in ``node_data``.
        Finally it prints, for diagnosis only, any node whose ``attrs``
        dict has no ``'name'`` key.
        """
        with jt.flag_scope(trace_py_var=2):

            resnet18 = resnet.Resnet18()
            opt = jt.optim.SGD(resnet18.parameters(), 0.1)
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)

            opt.step(y**2)
            jt.sync_all()

            data = jt.dump_trace_data()
            jt.clear_trace_data()
            dump_path = f"{jt.flags.cache_path}/resnet_train.pkl"
            with open(dump_path, "wb") as out:
                pickle.dump(data, out)

            known_nodes = data["node_data"]
            for op_info in data["execute_op_info"].values():
                for fused_id in op_info['fused_ops']:
                    assert fused_id in known_nodes, (fused_id, "not found")

            # diagnostic only: show nodes whose attrs lack a 'name' entry
            for node in known_nodes.values():
                if 'name' not in node["attrs"]:
                    print(node)