def get_flops(model, input_shape):
    """Return formatted ``(flops, params)`` strings for ``model``.

    Base counts come from ``get_model_complexity_info``. When ``model.name``
    contains ``'pvt'``, a per-stage attention term is added for the four
    stages ``block1`` .. ``block4``:

    * if the name also contains ``'li'`` (PVTv2_li), ``li_sra_flops`` is used;
    * otherwise (PVT / PVTv2), ``sra_flops`` is used with the stage's
      ``sr_ratio``.

    Each stage's term is the per-block cost, evaluated on an
    ``(H // s) x (W // s)`` grid with ``s`` = 4, 8, 16, 32, multiplied by
    the number of blocks in that stage.

    Args:
        model: Network exposing ``name``. PVT variants must also expose
            ``block1`` .. ``block4``, whose first element has an ``attn``
            submodule with ``dim`` (and ``sr_ratio`` for the non-``li`` path).
        input_shape: Three-element shape unpacked as ``(_, H, W)``.
            Presumably ``(C, H, W)``; confirm against callers.

    Returns:
        Tuple ``(flops_to_string(flops), params_to_string(params))``.
    """
    flops, params = get_model_complexity_info(model, input_shape, as_strings=False)

    if 'pvt' in model.name:
        _, height, width = input_shape
        linear_variant = 'li' in model.name

        attention_total = 0
        for stage_no, stride in enumerate((4, 8, 16, 32), start=1):
            stage_blocks = getattr(model, f'block{stage_no}')
            attn = stage_blocks[0].attn
            grid_h, grid_w = height // stride, width // stride
            if linear_variant:
                # calculate flops of PVTv2_li
                per_block = li_sra_flops(grid_h, grid_w, attn.dim)
            else:
                # calculate flops of PVT/PVTv2
                per_block = sra_flops(grid_h, grid_w, attn.sr_ratio, attn.dim)
            attention_total += per_block * len(stage_blocks)

        flops += attention_total

    return flops_to_string(flops), params_to_string(params)
def get_vit_flops(net, input_shape, patch_size):
    """Return formatted ``(flops, params)`` strings for a ViT-style ``net``.

    Adds a multi-head-attention term on top of ``get_model_complexity_info``.
    The term is ``mha_flops`` evaluated on an
    ``(H // patch_size) x (W // patch_size)`` grid using the first block's
    ``attn.dim`` and ``attn.num_heads``, multiplied by ``len(net.blocks)``.

    Args:
        net: Model exposing ``blocks``; ``blocks[0].attn`` must have ``dim``
            and ``num_heads``.
        input_shape: Three-element shape unpacked as ``(_, H, W)``.
        patch_size: Integer divisor applied to H and W to get the grid size.

    Returns:
        Tuple ``(flops_to_string(flops), params_to_string(params))``.
    """
    flops, params = get_model_complexity_info(net, input_shape, as_strings=False)

    _, height, width = input_shape
    grid_h = height // patch_size
    grid_w = width // patch_size

    first_attn = net.blocks[0].attn
    per_block = mha_flops(grid_h, grid_w, first_attn.dim, first_attn.num_heads)
    flops += per_block * len(net.blocks)

    return flops_to_string(flops), params_to_string(params)
def test_params_to_string():
    """Check ``params_to_string`` output across magnitudes, units and precision."""
    # No explicit unit: the expected suffix ('M', 'k', or none) varies with magnitude.
    default_unit_cases = [
        (3.21 * 10.**7, '32.1 M'),
        (4.56 * 10.**5, '456.0 k'),
        (7.89 * 10.**2, '789.0'),
    ]
    for count, expected in default_unit_cases:
        assert params_to_string(count) == expected

    # Explicit unit (positional) or precision (keyword) on one fixed value.
    count = 6.54321 * 10.**7
    explicit_cases = [
        (('M',), {}, '65.43 M'),
        (('K',), {}, '65432.1 K'),
        (('',), {}, '65432100.0'),
        ((), {'precision': 4}, '65.4321 M'),
    ]
    for extra_args, extra_kwargs, expected in explicit_cases:
        assert params_to_string(count, *extra_args, **extra_kwargs) == expected
def get_pvt_flops(net, input_shape):
    """Return formatted ``(flops, params)`` strings for a PVT-style ``net``.

    Adds a spatial-reduction-attention term for each of the four stages
    ``block1`` .. ``block4`` on top of ``get_model_complexity_info``. Each
    stage's term is ``sra_flops`` evaluated on an ``(H // s) x (W // s)``
    grid, with ``s`` = 4, 8, 16, 32, using the first block's ``sr_ratio``,
    ``dim`` and ``num_heads``, multiplied by the number of blocks in that
    stage.

    Args:
        net: Model exposing ``block1`` .. ``block4``; ``blockN[0].attn`` must
            have ``sr_ratio``, ``dim`` and ``num_heads``.
        input_shape: Three-element shape unpacked as ``(_, H, W)``.
            Presumably ``(C, H, W)``; confirm against callers.

    Returns:
        Tuple ``(flops_to_string(flops), params_to_string(params))``.
    """
    flops, params = get_model_complexity_info(net, input_shape, as_strings=False)
    _, height, width = input_shape

    attention_flops = 0
    for stage_no, stride in enumerate((4, 8, 16, 32), start=1):
        stage_blocks = getattr(net, f'block{stage_no}')
        attn = stage_blocks[0].attn
        per_block = sra_flops(
            height // stride,
            width // stride,
            attn.sr_ratio,
            attn.dim,
            attn.num_heads,
        )
        attention_flops += per_block * len(stage_blocks)

    # The original printed the attention total here as leftover debug output;
    # a library helper should not write to stdout, so it is omitted.
    flops += attention_flops
    return flops_to_string(flops), params_to_string(params)