示例#1
0
def test():
    """Declare named dims, enable tsanley's dynamic analyzer, then run ``test_func``.

    ``test_func`` is not defined in this snippet; it is presumably provided
    by the surrounding module.
    """
    # Register the named dimension variables via the tsalib API; they can be
    # recalled anywhere else in the program with `get_dim_vars`.
    from tsalib import dim_vars
    dims_spec = 'Batch(b):10 Length(t):100 Hidden(d):1024'
    dim_vars(dims_spec)

    # Trace functions selected by the 'f*' name pattern (presumably a glob —
    # confirm against tsanley) and report shape updates as they occur.
    from tsanley.dynamic import init_analyzer
    patterns = ['f*']
    init_analyzer(trace_func_names=patterns, show_updates=True)

    test_func()
示例#2
0
def test2():
    """Declare named dims, start tsanley's shape analyzer on ``foo``, run ``test_foo``.

    ``test_foo`` is not defined in this snippet; presumably it lives in the
    surrounding module.
    """
    # Named dimension variables, declared through the tsalib API.
    from tsalib import dim_vars
    dims_spec = 'Batch(b):10 Length(t):100 Hidden(d):1024'
    dim_vars(dims_spec)

    # Start tsanley's dynamic shape analyzer for `foo`. Debug output is
    # explicitly off; other flags such as check_tsa are not passed and so
    # take init_analyzer's defaults.
    from tsanley.dynamic import init_analyzer
    traced = ['foo']
    init_analyzer(
        trace_func_names=traced,
        show_updates=True,
        debug=False,
    )

    test_foo()
示例#3
0
def test_resnet():
    """Run ResNet-18 on an all-ones batch while tsanley traces ``ResNet.forward``.

    Relies on module-level ``dim_vars``, ``torch`` and ``resnet18``, which are
    not visible in this snippet.
    """
    # Dimension variables must be declared for checking. Ci, Co and Ex are not
    # used locally; presumably the analyzer needs them to resolve annotations
    # inside the ResNet code (TODO confirm).
    B, C, Ci, Co = dim_vars('Batch(b):10 Channels(c):3 ChannelsIn(ci) ChannelsOut(co)')
    H, W, Ex = dim_vars('Height(h):224 Width(w):224 BlockExpansion(e):1')

    net = resnet18()
    # These literal sizes mirror the declared Batch:10, Channels:3,
    # Height:224 and Width:224; keep them in sync if the declarations change.
    x: 'bchw' = torch.ones(10, 3, 224, 224)

    from tsanley.dynamic import init_analyzer
    # Only ResNet.forward is traced; 'BasicBlock.forward' can be appended to
    # the list to trace that method too.
    init_analyzer(trace_func_names=['ResNet.forward'])
    y = net.forward(x)
    print(y.size())
示例#4
0
def setup_named_dims():
    """Start tsanley's dynamic analyzer, tracing ``Net.forward`` with shape updates shown.

    NOTE(review): despite the name, no dimension variables are declared in the
    visible code; presumably they are declared elsewhere in the module.
    """
    from tsanley.dynamic import init_analyzer
    traced = ['Net.forward']
    # check_tsa / debug are not passed, so init_analyzer's defaults apply.
    init_analyzer(trace_func_names=traced, show_updates=True)
示例#5
0
文件: gnn.py 项目: victor8733/tsanley
def test_gnn():
    """Run ``main`` while tsanley traces the GGNN node-representation method.

    ``main`` is not defined in this snippet; presumably it is the module's
    entry point.
    """
    from tsanley.dynamic import init_analyzer
    target = 'GatedGraphNeuralNetwork.compute_node_representations'
    # The list of function names to trace is passed positionally.
    init_analyzer([target])
    main()
示例#6
0
def test_effnet():
    """Run EffNet on an all-ones input while tsanley traces ``EffNet.forward``.

    ``EffNet``, ``init_analyzer``, ``torch`` and the dimension variables
    ``B``, ``C``, ``H``, ``W`` all come from module scope (not visible here).
    """
    model = EffNet()
    x: 'bchw' = torch.ones(B, C, H, W)
    init_analyzer(['EffNet.forward'])
    result = model.forward(x)
    print(result.size())
示例#7
0
def setup_named_dims():
    """Start tsanley's dynamic analyzer on ``Net.forward`` and ``AGNNConv.forward``.

    NOTE(review): despite the name, no dimension variables are declared in the
    visible code; presumably they are declared elsewhere in the module.
    """
    from tsanley.dynamic import init_analyzer
    traced = [
        'Net.forward',
        'AGNNConv.forward',
    ]
    init_analyzer(trace_func_names=traced, show_updates=True)