Beispiel #1
0
def get_flops(model, input_shape):
    """Return formatted FLOPs and parameter counts for ``model``.

    The base counts come from ``get_model_complexity_info``. When
    ``model.name`` contains ``'pvt'``, an attention term is added for each of
    the four stages ``block1`` .. ``block4``. Each stage is evaluated at the
    input height/width floor-divided by 4, 8, 16 and 32 respectively, and
    multiplied by the number of blocks in that stage:

    * names that also contain ``'li'`` use ``li_sra_flops(h, w, dim)``;
    * all other PVT names use ``sra_flops(h, w, sr_ratio, dim)``.

    Args:
        model: Network to profile. It must expose a ``name`` attribute. For
            PVT variants it must also have ``block1`` .. ``block4``, whose
            first element carries an ``attn`` module with ``dim`` (and
            ``sr_ratio`` for the non-``li`` variants).
        input_shape: Three-element shape whose last two entries are used as
            height and width (presumably ``(C, H, W)`` -- confirm with
            callers).

    Returns:
        tuple: ``(flops_to_string(flops), params_to_string(params))``.
    """
    flops, params = get_model_complexity_info(model,
                                              input_shape,
                                              as_strings=False)
    # Non-PVT models: the generic counter's figures are reported unchanged.
    if 'pvt' not in model.name:
        return flops_to_string(flops), params_to_string(params)

    _, H, W = input_shape
    use_linear_sra = 'li' in model.name  # PVTv2_li vs. PVT/PVTv2
    stages = ((model.block1, 4), (model.block2, 8),
              (model.block3, 16), (model.block4, 32))
    attention_flops = 0
    for blocks, stride in stages:
        attn = blocks[0].attn
        if use_linear_sra:
            per_block = li_sra_flops(H // stride, W // stride, attn.dim)
        else:
            per_block = sra_flops(H // stride, W // stride, attn.sr_ratio,
                                  attn.dim)
        # Every block in a stage is assumed to share the first block's
        # attention configuration.
        attention_flops += per_block * len(blocks)
    flops += attention_flops
    return flops_to_string(flops), params_to_string(params)
Beispiel #2
0
def get_vit_flops(net, input_shape, patch_size):
    """Return formatted FLOPs and parameter counts for a ViT-style ``net``.

    The base counts from ``get_model_complexity_info`` are increased by
    ``mha_flops`` evaluated at the input height and width floor-divided by
    ``patch_size``, using the first block's ``attn.dim`` and
    ``attn.num_heads``, multiplied by ``len(net.blocks)``.

    NOTE(review): presumably the attention term is added because the generic
    counter does not capture the attention matrix products -- confirm.

    Args:
        net: Network exposing ``blocks``; ``blocks[0].attn`` must provide
            ``dim`` and ``num_heads``.
        input_shape: Three-element shape whose last two entries are used as
            height and width.
        patch_size: Divisor applied (floor division) to height and width.

    Returns:
        tuple: ``(flops_to_string(flops), params_to_string(params))``.
    """
    flops, params = get_model_complexity_info(net,
                                              input_shape,
                                              as_strings=False)
    _, height, width = input_shape
    first_attn = net.blocks[0].attn
    grid_h = height // patch_size
    grid_w = width // patch_size
    per_block = mha_flops(grid_h, grid_w, first_attn.dim,
                          first_attn.num_heads)
    flops += per_block * len(net.blocks)
    return flops_to_string(flops), params_to_string(params)
Beispiel #3
0
def test_params_to_string():
    """Check ``params_to_string`` formatting with automatic and explicit units."""
    # No unit argument: the helper chooses the unit itself.
    automatic_cases = (
        (3.21 * 10.**7, '32.1 M'),
        (4.56 * 10.**5, '456.0 k'),
        (7.89 * 10.**2, '789.0'),
    )
    for count, expected in automatic_cases:
        assert params_to_string(count) == expected

    # Explicit unit (passed positionally) and a custom precision.
    count = 6.54321 * 10.**7
    for unit, expected in (('M', '65.43 M'),
                           ('K', '65432.1 K'),
                           ('', '65432100.0')):
        assert params_to_string(count, unit) == expected
    assert params_to_string(count, precision=4) == '65.4321 M'
Beispiel #4
0
def get_pvt_flops(net, input_shape):
    """Return formatted FLOPs and parameter counts for a PVT network.

    The base counts from ``get_model_complexity_info`` are increased by an
    attention term for each of the four stages ``block1`` .. ``block4``. Each
    stage's term is ``sra_flops(h, w, sr_ratio, dim, num_heads)``, with ``h``
    and ``w`` the input height/width floor-divided by 4, 8, 16 and 32
    respectively. The term uses the first block's attention settings and is
    multiplied by the number of blocks in that stage.

    Args:
        net: Network with ``block1`` .. ``block4``; ``blockN[0].attn`` must
            provide ``sr_ratio``, ``dim`` and ``num_heads``.
        input_shape: Three-element shape whose last two entries are used as
            height and width (presumably ``(C, H, W)`` -- confirm with
            callers).

    Returns:
        tuple: ``(flops_to_string(flops), params_to_string(params))``.
    """
    flops, params = get_model_complexity_info(net,
                                              input_shape,
                                              as_strings=False)
    _, H, W = input_shape
    stages = ((net.block1, 4), (net.block2, 8),
              (net.block3, 16), (net.block4, 32))
    attention_flops = 0
    for blocks, stride in stages:
        attn = blocks[0].attn
        # All blocks in a stage are assumed to share the first block's
        # attention configuration.
        attention_flops += sra_flops(H // stride, W // stride, attn.sr_ratio,
                                     attn.dim, attn.num_heads) * len(blocks)
    # No stdout output here: callers receive the totals via the return value.
    flops += attention_flops
    return flops_to_string(flops), params_to_string(params)