예제 #1
0
def test_regvgg_recognizer():
    """Check that inserting and then fusing RepVGG/AC blocks preserves outputs.

    For every architecture key in ``arch_settings`` a ``RepVGGRecognizer`` is
    built and its output shape is checked. Three insert -> fuse round trips
    then run on that same model, each comparing the eval-mode output after
    the insert step with the eval-mode output after the fuse step.

    Fusion is applied in the reverse order of insertion: when RepVGG blocks
    are inserted before AC blocks, AC blocks are fused first, and vice versa.
    """
    data = torch.randn(1, 3, 224, 224)

    def _forward(model):
        # Inference only; no autograd graph is needed for the comparison.
        with torch.no_grad():
            return model(data)[KEY_OUTPUT]

    def _check_round_trip(model, insert_fns, fuse_fns, atol):
        """Apply ``insert_fns`` then ``fuse_fns`` in order and compare outputs."""
        for insert_fn in insert_fns:
            insert_fn(model)
        model.eval()
        outputs_insert = _forward(model)

        for fuse_fn in fuse_fns:
            fuse_fn(model)
        # eval() is re-applied after fusion, as in the original flow; fusion
        # may create new modules -- TODO confirm whether this is required.
        model.eval()
        outputs_fuse = _forward(model)

        print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2)))
        is_close = torch.allclose(outputs_insert, outputs_fuse, atol=atol)
        print(is_close)
        assert is_close

    # (description, insert functions in order, fuse functions in order, atol)
    scenarios = (
        ('insert_regvgg_block -> fuse_regvgg_block',
         (insert_repvgg_block,),
         (fuse_repvgg_block,),
         1e-8),
        ('insert_regvgg_block -> insert_acblock -> fuse_acblock -> fuse_regvgg_block',
         (insert_repvgg_block, insert_acblock),
         (fuse_acblock, fuse_repvgg_block),
         1e-7),
        # AC blocks are inserted first here, so RepVGG blocks are fused first.
        ('insert_acblock -> insert_regvgg_block -> fuse_regvgg_block -> fuse_acblock',
         (insert_acblock, insert_repvgg_block),
         (fuse_repvgg_block, fuse_acblock),
         1e-7),
    )

    for key in arch_settings:
        print('*' * 10, key)
        model = RepVGGRecognizer(arch=key)
        outputs = _forward(model)
        assert outputs.shape == (1, 1000)

        # The model is reused: each scenario starts from the previously
        # fused model.
        for description, insert_fns, fuse_fns, atol in scenarios:
            print(description)
            _check_round_trip(model, insert_fns, fuse_fns, atol)
예제 #2
0
def test_resnet50_acb():
    """Verify that fusing AC blocks into a ResNet-50 leaves its output unchanged.

    AC blocks are inserted into the model returned by ``resnet50()``. The
    eval-mode output is recorded, the AC blocks are fused with ``eps=1e-5``,
    and the new eval-mode output must match within ``atol=1e-7``.
    """
    net = resnet50()
    net.eval()

    # Created after the network so the random draw order matches the
    # original test exactly.
    sample = torch.randn(1, 3, 224, 224)

    # Expand the network with AC blocks and record the reference output.
    insert_acblock(net)
    net.eval()
    before_fuse = net(sample)

    # Fuse the inserted AC blocks and record the output again.
    fuse_acblock(net, eps=1e-5)
    net.eval()
    after_fuse = net(sample)

    l2_gap = torch.sqrt(torch.sum((before_fuse - after_fuse)**2))
    print(l2_gap)
    matches = torch.allclose(before_fuse, after_fuse, atol=1e-7)
    print(matches)
    assert matches
예제 #3
0
def test_acb_helper():
    """Round-trip AC-block insertion and fusion on a small strided, grouped stack.

    The model is ``Sequential(Sequential(conv, BN, ReLU), BN, ReLU)`` with a
    3x3, stride-2, 8-group convolution. AC blocks are inserted and the
    eval-mode output is recorded. They are then fused with ``eps=1e-5`` and
    the new eval-mode output must match within ``atol=1e-6``.
    """
    in_channels, out_channels = 32, 64
    dilation = 1

    # Downsampling + grouped convolution.
    kernel_size, stride, padding, groups = 3, 2, 1, 8

    conv = nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size,
                     stride=stride,
                     padding=padding,
                     dilation=dilation,
                     groups=groups)
    conv_bn_relu = nn.Sequential(conv,
                                 nn.BatchNorm2d(out_channels),
                                 nn.ReLU(inplace=True))
    model = nn.Sequential(conv_bn_relu,
                          nn.BatchNorm2d(out_channels),
                          nn.ReLU(inplace=True))
    model.eval()
    print(model)

    # Drawn after the layers are built, keeping the original RNG order.
    data = torch.randn(1, in_channels, 56, 56)

    insert_acblock(model)
    model.eval()
    expanded_out = model(data)
    print(model)

    fuse_acblock(model, eps=1e-5)
    model.eval()
    fused_out = model(data)
    print(model)

    print(torch.sqrt(torch.sum((expanded_out - fused_out)**2)))
    agree = torch.allclose(expanded_out, fused_out, atol=1e-6)
    print(agree)
    assert agree
예제 #4
0
def test_config_file():
    """Build recognizers from RepVGG/ACB benchmark configs and fuse them.

    For every config the recognizer is built on CPU and its output shape is
    checked to be (3, 100). The embedded blocks are then fused in the order
    listed for that config, printing the model after each fuse step, and the
    output shape is checked again.
    """
    data = torch.randn(3, 3, 224, 224)
    device = torch.device('cpu')
    expected_shape = (3, 100)

    # Each entry: (label, config file, fuse functions applied in order).
    # Fusion order is the reverse of the order in which the blocks were
    # embedded into the model.
    cases = [
        ('repvgg_b2g4_custom_cifar100_224_e100_sgd',
         "configs/benchmarks/repvgg/repvgg_b2g4_cifar100_224_e100_sgd_calr.yaml",
         [fuse_repvgg_block]),
        # ACBlock embedded inside RepVGG: fuse ACB first, then RepVGG.
        ('repvgg_b2g4_acb_custom_cifar100_224_e100_sgd',
         "configs/benchmarks/repvgg/repvgg_b2g4_acb_cifar100_224_e100_sgd_calr.yaml",
         [fuse_acblock, fuse_repvgg_block]),
        # ACBlock embedded first, then RepVGGBlock: fuse RepVGG first, then ACB.
        ('acb_repvgg_b2g4_custom_cifar100_224_e100_sgd',
         "configs/benchmarks/repvgg/acb_repvgg_b2g4_cifar100_224_e100_sgd_calr.yaml",
         [fuse_repvgg_block, fuse_acblock]),
        # ACBlock embedded first, then RepVGGBlock: fuse RepVGG first, then ACB.
        ('rxtd50_32x4d_acb_rvb_custom_cifar100_224_e100_sgd',
         "configs/benchmarks/repvgg/rxtd50_32x4d_acb_rvb_cifar100_224_e100_sgd_calr.yaml",
         [fuse_repvgg_block, fuse_acblock]),
        # RepVGGBlock embedded first, then ACBlock: fuse in reverse order.
        ('rxtd50_32x4d_rvb_acb_custom_cifar100_224_e100_sgd',
         "configs/benchmarks/repvgg/rxtd50_32x4d_rvb_acb_cifar100_224_e100_sgd_calr.yaml",
         [fuse_acblock, fuse_repvgg_block]),
    ]

    for label, config_file, fuse_steps in cases:
        print(label)
        # NOTE(review): each file is merged into the same shared ``cfg``, so
        # keys set by an earlier config and not overridden by a later one
        # presumably carry over -- TODO confirm this is intended.
        cfg.merge_from_file(config_file)

        model = build_recognizer(cfg, device)
        print(model)
        assert model(data)[KEY_OUTPUT].shape == expected_shape

        for fuse in fuse_steps:
            fuse(model)
            print(model)
        assert model(data)[KEY_OUTPUT].shape == expected_shape
예제 #5
0
def test_regvgg():
    """Check RepVGG/AC-block insert -> fuse round trips on a config-built RepVGG.

    The backbone is built with ``RepVGG(cfg)`` after merging the RepVGG-B2g4
    CIFAR-100 config, and its output shape is checked to be (1, 100). Three
    insert -> fuse round trips then run on that same model, each comparing
    the eval-mode output after insertion with the eval-mode output after
    fusion.

    Fusion is applied in the reverse order of insertion: when RepVGG blocks
    are inserted before AC blocks, AC blocks are fused first, and vice versa.
    """
    data = torch.randn(1, 3, 224, 224)
    # The config does not depend on the loop variable, so merge it once.
    cfg.merge_from_file(
        'configs/benchmarks/repvgg/repvgg_b2g4_cifar100_224_e100_sgd_calr.yaml'
    )

    def _forward(model):
        # Inference only; no autograd graph is needed for the comparison.
        with torch.no_grad():
            return model(data)[KEY_OUTPUT]

    def _check_round_trip(model, insert_fns, fuse_fns, atol):
        """Apply ``insert_fns`` then ``fuse_fns`` in order and compare outputs."""
        for insert_fn in insert_fns:
            insert_fn(model)
        model.eval()
        outputs_insert = _forward(model)

        for fuse_fn in fuse_fns:
            fuse_fn(model)
        # eval() is re-applied after fusion, as in the original flow; fusion
        # may create new modules -- TODO confirm whether this is required.
        model.eval()
        outputs_fuse = _forward(model)

        print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2)))
        is_close = torch.allclose(outputs_insert, outputs_fuse, atol=atol)
        print(is_close)
        assert is_close

    # (description, insert functions in order, fuse functions in order, atol)
    scenarios = (
        ('insert_regvgg_block -> fuse_regvgg_block',
         (insert_repvgg_block,),
         (fuse_repvgg_block,),
         1e-8),
        ('insert_regvgg_block -> insert_acblock -> fuse_acblock -> fuse_regvgg_block',
         (insert_repvgg_block, insert_acblock),
         (fuse_acblock, fuse_repvgg_block),
         1e-6),
        # AC blocks are inserted first here, so RepVGG blocks are fused first.
        ('insert_acblock -> insert_regvgg_block -> fuse_regvgg_block -> fuse_acblock',
         (insert_acblock, insert_repvgg_block),
         (fuse_repvgg_block, fuse_acblock),
         1e-6),
    )

    # NOTE(review): ``key`` is only printed -- ``RepVGG(cfg)`` does not receive
    # it, so every iteration builds the same model. TODO: select the arch via
    # cfg if per-arch coverage is intended.
    for key in arch_settings:
        print('*' * 10, key)
        model = RepVGG(cfg)
        outputs = _forward(model)
        assert outputs.shape == (1, 100)

        # The model is reused: each scenario starts from the previously
        # fused model.
        for description, insert_fns, fuse_fns, atol in scenarios:
            print(description)
            _check_round_trip(model, insert_fns, fuse_fns, atol)