Esempio n. 1
0
def build_recognizer(cfg, device):
    """Build the recognizer described by ``cfg`` and prepare it for ``device``.

    Steps, in order:
      1. Instantiate ``registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.TYPE](cfg)``.
      2. If ``cfg.MODEL.NORM.SYNC_BN`` is set and more than one process is
         running, convert the model with ``convert_sync_bn`` on
         ``du.LOCAL_PROCESS_GROUP``.
      3. If ``cfg.MODEL.RECOGNIZER.PRELOADED`` is non-empty, load those
         weights via ``CheckPointer`` with ``map_location=device``.
      4. Insert every block type listed in ``cfg.MODEL.CONV.ADD_BLOCKS``
         (``'RepVGGBlock'`` and/or ``'ACBlock'``), in the listed order.
      5. Move the model to ``device`` and, when more than one process is
         running, wrap it in ``DDP``.

    Args:
        cfg: project config node.
        device: target device; also used as ``map_location`` for the
            checkpoint and as the DDP device id / output device.

    Returns:
        The prepared model (DDP-wrapped when world size > 1).

    Raises:
        TypeError: if ``cfg.MODEL.CONV.ADD_BLOCKS`` is neither ``None`` nor a
            tuple/list.
        ValueError: if ``cfg.MODEL.CONV.ADD_BLOCKS`` names an unsupported
            block type.
    """
    world_size = du.get_world_size()

    model = registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.TYPE](cfg)

    if cfg.MODEL.NORM.SYNC_BN and world_size > 1:
        # NOTE(review): the message reports du.LOCAL_RANK_GROUP while the
        # conversion uses du.LOCAL_PROCESS_GROUP -- confirm they describe the
        # same group.
        logger.info("start sync BN on the process group of %s",
                    du.LOCAL_RANK_GROUP)
        convert_sync_bn(model, du.LOCAL_PROCESS_GROUP)

    preloaded = cfg.MODEL.RECOGNIZER.PRELOADED
    # Truthiness also treats None as "nothing to load" (the old `!= ""`
    # check would have tried to load None).
    if preloaded:
        logger.info('load pretrained: %s', preloaded)
        check_pointer = CheckPointer(model)
        check_pointer.load(preloaded, map_location=device)
        logger.info("finish loading model weights")

    add_blocks = cfg.MODEL.CONV.ADD_BLOCKS
    if add_blocks is not None:
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(add_blocks, (tuple, list)):
            raise TypeError(
                "cfg.MODEL.CONV.ADD_BLOCKS must be a tuple or list, got "
                f"{type(add_blocks).__name__}")
        inserters = {
            'RepVGGBlock': insert_repvgg_block,
            'ACBlock': insert_acblock,
        }
        for add_block in add_blocks:
            try:
                inserter = inserters[add_block]
            except KeyError:
                # Previously an unknown name was silently skipped, hiding
                # config typos.
                raise ValueError(
                    f"unsupported block in cfg.MODEL.CONV.ADD_BLOCKS: "
                    f"{add_block!r} (expected one of {sorted(inserters)})"
                ) from None
            inserter(model)

    model = model.to(device=device)
    if world_size > 1:
        model = DDP(model,
                    device_ids=[device],
                    output_device=device,
                    find_unused_parameters=True)

    return model
Esempio n. 2
0
def test_regvgg_recognizer():
    """Check that inserting then fusing RepVGG/AC blocks preserves outputs.

    For every key in ``arch_settings``:
      * build ``RepVGGRecognizer(arch=key)`` and check that the output for one
        random 1x3x224x224 input has shape ``(1, 1000)``;
      * run three insert -> fuse stages on that same model. Each stage checks
        that the eval-mode output before fusing matches the output after
        fusing within the stage's ``atol``.
    """
    data = torch.randn(1, 3, 224, 224)
    # (label, insert functions in call order, fuse functions in call order,
    #  atol). Function names are resolved at call time, so this stays valid
    # wherever the helpers are imported from.
    stages = (
        ('insert_regvgg_block -> fuse_regvgg_block',
         (insert_repvgg_block,),
         (fuse_repvgg_block,),
         1e-8),
        ('insert_regvgg_block -> insert_acblock -> fuse_acblock -> fuse_regvgg_block',
         (insert_repvgg_block, insert_acblock),
         (fuse_acblock, fuse_repvgg_block),
         1e-7),
        # Opposite ordering to the previous stage, as the label states.
        ('insert_acblock -> insert_regvgg_block -> fuse_regvgg_block -> fuse_acblock',
         (insert_acblock, insert_repvgg_block),
         (fuse_repvgg_block, fuse_acblock),
         1e-7),
    )

    for key in arch_settings.keys():
        print('*' * 10, key)
        model = RepVGGRecognizer(arch=key)
        outputs = model(data)[KEY_OUTPUT]
        assert outputs.shape == (1, 1000)

        for label, insert_fns, fuse_fns, atol in stages:
            print(label)
            for insert in insert_fns:
                insert(model)
            model.eval()
            outputs_insert = model(data)[KEY_OUTPUT]

            for fuse in fuse_fns:
                fuse(model)
            model.eval()
            outputs_fuse = model(data)[KEY_OUTPUT]

            print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2)))
            close = torch.allclose(outputs_insert, outputs_fuse, atol=atol)
            print(close)
            assert close, f'{key}: outputs differ after {label}'
Esempio n. 3
0
def test_resnet50_acb():
    """Check that fusing ACBlocks into a ResNet-50 keeps its eval output.

    ACBlocks are inserted into ``resnet50()`` and one random 1x3x224x224 input
    is evaluated. The blocks are then fused with ``fuse_acblock(..., eps=1e-5)``
    and the same input is evaluated again. Both outputs are computed in eval
    mode and must agree to ``atol=1e-7``.
    """
    net = resnet50()
    net.eval()

    sample = torch.randn(1, 3, 224, 224)

    # Forward pass with the ACBlocks inserted but not yet fused.
    insert_acblock(net)
    net.eval()
    before_fuse = net(sample)

    # Forward pass after fusing the ACBlocks.
    fuse_acblock(net, eps=1e-5)
    net.eval()
    after_fuse = net(sample)

    gap = before_fuse - after_fuse
    print(torch.sqrt(torch.sum(gap**2)))
    matches = torch.allclose(before_fuse, after_fuse, atol=1e-7)
    print(matches)
    assert matches
Esempio n. 4
0
def test_acb_helper():
    """Check ACBlock insert/fuse on a small strided, grouped conv stack.

    The model is ``Sequential(Sequential(Conv2d, BN, ReLU), BN, ReLU)``. Its
    convolution is 3x3, stride 2, padding 1, 8 groups, 32 -> 64 channels.
    ACBlocks are inserted, then fused with ``eps=1e-5``. The eval-mode outputs
    on a random 1x32x56x56 input must agree to ``atol=1e-6``.
    """
    in_channels, out_channels = 32, 64
    dilation = 1

    # Downsampling + grouped convolution
    kernel_size, stride, padding, groups = 3, 2, 1, 8

    conv_block = nn.Sequential(
        nn.Conv2d(in_channels,
                  out_channels,
                  kernel_size,
                  stride=stride,
                  padding=padding,
                  dilation=dilation,
                  groups=groups),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
    net = nn.Sequential(
        conv_block,
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    )
    net.eval()
    print(net)

    sample = torch.randn(1, in_channels, 56, 56)

    # Forward pass with the ACBlocks inserted but not yet fused.
    insert_acblock(net)
    net.eval()
    unfused_out = net(sample)
    print(net)

    # Forward pass after fusing the ACBlocks.
    fuse_acblock(net, eps=1e-5)
    net.eval()
    fused_out = net(sample)
    print(net)

    gap = unfused_out - fused_out
    print(torch.sqrt(torch.sum(gap**2)))
    is_close = torch.allclose(unfused_out, fused_out, atol=1e-6)
    print(is_close)
    assert is_close
Esempio n. 5
0
def test_regvgg():
    """Check that inserting then fusing RepVGG/AC blocks preserves outputs.

    The model is built from the RepVGG-B2g4 CIFAR-100 config. For each key in
    ``arch_settings``:
      * build ``RepVGG(cfg)`` and check that the output for one random
        1x3x224x224 input has shape ``(1, 100)``;
      * run three insert -> fuse stages on that same model. Each stage checks
        that the eval-mode output before fusing matches the output after
        fusing within the stage's ``atol``.
    """
    data = torch.randn(1, 3, 224, 224)
    # Loop-invariant: merging the same file on every iteration yields the
    # same config, so merge it once.
    cfg.merge_from_file(
        'configs/benchmarks/repvgg/repvgg_b2g4_cifar100_224_e100_sgd_calr.yaml'
    )
    # (label, insert functions in call order, fuse functions in call order,
    #  atol)
    stages = (
        ('insert_regvgg_block -> fuse_regvgg_block',
         (insert_repvgg_block,),
         (fuse_repvgg_block,),
         1e-8),
        ('insert_regvgg_block -> insert_acblock -> fuse_acblock -> fuse_regvgg_block',
         (insert_repvgg_block, insert_acblock),
         (fuse_acblock, fuse_repvgg_block),
         1e-6),
        # Opposite ordering to the previous stage, as the label states.
        ('insert_acblock -> insert_regvgg_block -> fuse_regvgg_block -> fuse_acblock',
         (insert_acblock, insert_repvgg_block),
         (fuse_repvgg_block, fuse_acblock),
         1e-6),
    )

    for key in arch_settings.keys():
        print('*' * 10, key)
        # NOTE(review): `key` is only printed. The architecture comes from the
        # config file above, so every iteration builds the same model.
        # Confirm whether the arch should be selected from `key`.
        model = RepVGG(cfg)
        outputs = model(data)[KEY_OUTPUT]
        assert outputs.shape == (1, 100)

        for label, insert_fns, fuse_fns, atol in stages:
            print(label)
            for insert in insert_fns:
                insert(model)
            model.eval()
            outputs_insert = model(data)[KEY_OUTPUT]

            for fuse in fuse_fns:
                fuse(model)
            model.eval()
            outputs_fuse = model(data)[KEY_OUTPUT]

            print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2)))
            close = torch.allclose(outputs_insert, outputs_fuse, atol=atol)
            print(close)
            assert close, f'{key}: outputs differ after {label}'