import torch
from prettytable import PrettyTable
from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis,
                       flop_count_str, parameter_count_table)

from model import efficientnetv2_s  # the tutorial's local EfficientNetV2 implementation


def main():
    model = efficientnetv2_s()

    # option1
    for name, para in model.named_parameters():
        # freeze all weights except those in the head
        if "head" not in name:
            para.requires_grad_(False)
        else:
            print("training {}".format(name))

    # complexity() is a custom method of this EfficientNetV2 implementation that
    # returns a dict with parameter, frozen-parameter, FLOPs and activation counts
    complexity = model.complexity(224, 224, 3)
    table = PrettyTable()
    table.field_names = ["params", "freeze-params", "train-params", "FLOPs", "acts"]
    table.add_row([complexity["params"],
                   complexity["freeze"],
                   complexity["params"] - complexity["freeze"],
                   complexity["flops"],
                   complexity["acts"]])
    print(table)

    # option2: use fvcore's FlopCountAnalysis
    tensor = (torch.rand(1, 3, 224, 224),)  # fvcore expects the inputs as a tuple of args
    flops = FlopCountAnalysis(model, tensor)
    print(flops.total())  # total FLOP count; unsupported ops are skipped with a warning

    print(parameter_count_table(model))
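
    # Cross-check (added sketch, not part of the original snippet): derive the
    # trainable/frozen split directly from requires_grad, independently of the
    # custom complexity() method. The numbers should match the table above.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("params: {}, freeze-params: {}, train-params: {}".format(
        total, total - trainable, trainable))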


# Second, standalone example: compare the FLOPs of single-head self-attention
# with multi-head attention. `Attention` is assumed to be the standard ViT
# attention block (qkv projection -> scaled dot-product -> output projection).
def main():
    # Self-Attention
    a1 = Attention(dim=512, num_heads=1)
    a1.proj = torch.nn.Identity()  # drop the output projection Wo so a1 is plain self-attention

    # Multi-Head Attention
    a2 = Attention(dim=512, num_heads=8)

    # [batch_size, num_tokens, total_embed_dim]
    t = (torch.rand(32, 1024, 512), )

    flops1 = FlopCountAnalysis(a1, t)
    print("Self-Attention FLOPs:", flops1.total())

    flops2 = FlopCountAnalysis(a2, t)
    print("Multi-Head Attention FLOPs:", flops2.total())


def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False, force_cpu=False):
    if force_cpu:
        model = model.to('cpu')
    # build the example input on the model's device and with its dtype
    param = next(model.parameters())
    device, dtype = param.device, param.dtype
    example_input = torch.ones((batch_size,) + input_size, device=device, dtype=dtype)
    fca = FlopCountAnalysis(model, example_input)
    aca = ActivationCountAnalysis(model, example_input)
    if detailed:
        # flop_count_str renders a per-module breakdown of the FLOP counts
        print(flop_count_str(fca))
    return fca.total(), aca.total()
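

# Usage sketch (assumption: torchvision is installed; resnet50 is only an
# illustrative model, not part of the original snippet):
from torchvision.models import resnet50

flops, acts = profile_fvcore(resnet50())
print("FLOPs: {:.2f} G, activations: {:.2f} M".format(flops / 1e9, acts / 1e6))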