def build_recognizer(cfg, device): world_size = du.get_world_size() model = registry.RECOGNIZER[cfg.MODEL.RECOGNIZER.TYPE](cfg) if cfg.MODEL.NORM.SYNC_BN and world_size > 1: logger.info("start sync BN on the process group of {}".format( du.LOCAL_RANK_GROUP)) convert_sync_bn(model, du.LOCAL_PROCESS_GROUP) preloaded = cfg.MODEL.RECOGNIZER.PRELOADED if preloaded != "": logger.info(f'load pretrained: {preloaded}') check_pointer = CheckPointer(model) check_pointer.load(preloaded, map_location=device) logger.info("finish loading model weights") if cfg.MODEL.CONV.ADD_BLOCKS is not None: assert isinstance(cfg.MODEL.CONV.ADD_BLOCKS, tuple) for add_block in cfg.MODEL.CONV.ADD_BLOCKS: if add_block == 'RepVGGBlock': insert_repvgg_block(model) if add_block == 'ACBlock': insert_acblock(model) model = model.to(device=device) if du.get_world_size() > 1: model = DDP(model, device_ids=[device], output_device=device, find_unused_parameters=True) return model
def test_regvgg(): model = RepVGG() model.eval() print(model) data = torch.randn(1, 3, 224, 224) insert_repvgg_block(model) model.eval() train_outputs = model(data)[KEY_OUTPUT] print(model) fuse_repvgg_block(model) model.eval() eval_outputs = model(data)[KEY_OUTPUT] print(model) print(torch.sqrt(torch.sum((train_outputs - eval_outputs)**2))) print(torch.allclose(train_outputs, eval_outputs, atol=1e-8)) assert torch.allclose(train_outputs, eval_outputs, atol=1e-8)
def test_regvgg(): cfg.merge_from_file('configs/benchmarks/repvgg/repvgg_b2g4_cifar100_224_e100_sgd_calr.yaml') model = RepVGG(cfg) model.eval() print(model) data = torch.randn(1, 3, 224, 224) insert_repvgg_block(model) model.eval() train_outputs = model(data)[KEY_OUTPUT] print(model) fuse_repvgg_block(model) model.eval() eval_outputs = model(data)[KEY_OUTPUT] print(model) print(torch.sqrt(torch.sum((train_outputs - eval_outputs) ** 2))) print(torch.allclose(train_outputs, eval_outputs, atol=1e-8)) assert torch.allclose(train_outputs, eval_outputs, atol=1e-8)
def test_regvgg_recognizer(): data = torch.randn(1, 3, 224, 224) for key in arch_settings.keys(): print('*' * 10, key) model = RepVGGRecognizer(arch=key) # print(model) outputs = model(data)[KEY_OUTPUT] assert outputs.shape == (1, 1000) print('insert_regvgg_block -> fuse_regvgg_block') insert_repvgg_block(model) # print(model) model.eval() outputs_insert = model(data)[KEY_OUTPUT] fuse_repvgg_block(model) # print(model) model.eval() outputs_fuse = model(data)[KEY_OUTPUT] # print(outputs_insert) # print(outputs_fuse) print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2))) print(torch.allclose(outputs_insert, outputs_fuse, atol=1e-8)) assert torch.allclose(outputs_insert, outputs_fuse, atol=1e-8) print( 'insert_regvgg_block -> insert_acblock -> fuse_acblock -> fuse_regvgg_block' ) insert_repvgg_block(model) insert_acblock(model) # print(model) model.eval() outputs_insert = model(data)[KEY_OUTPUT] fuse_acblock(model) fuse_repvgg_block(model) # print(model) model.eval() outputs_fuse = model(data)[KEY_OUTPUT] print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2))) print(torch.allclose(outputs_insert, outputs_fuse, atol=1e-7)) assert torch.allclose(outputs_insert, outputs_fuse, atol=1e-7) print( 'insert_acblock -> insert_regvgg_block -> fuse_regvgg_block -> fuse_acblock' ) insert_repvgg_block(model) insert_acblock(model) # print(model) model.eval() outputs_insert = model(data)[KEY_OUTPUT] fuse_acblock(model) fuse_repvgg_block(model) # print(model) model.eval() outputs_fuse = model(data)[KEY_OUTPUT] print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2))) print(torch.allclose(outputs_insert, outputs_fuse, atol=1e-7)) assert torch.allclose(outputs_insert, outputs_fuse, atol=1e-7)
def test_conv_helper(): in_channels = 32 out_channels = 64 dilation = 1 # 下采样 + 分组卷积 kernel_size = 3 stride = 2 padding = 1 groups = 8 model = nn.Sequential( nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) model.eval() print(model) data = torch.randn(1, in_channels, 56, 56) insert_repvgg_block(model) model.eval() train_outputs = model(data) print(model) fuse_repvgg_block(model) model.eval() eval_outputs = model(data) print(model) print(torch.sqrt(torch.sum((train_outputs - eval_outputs)**2))) print(torch.allclose(train_outputs, eval_outputs, atol=1e-8)) assert torch.allclose(train_outputs, eval_outputs, atol=1e-8)
def test_regvgg(): data = torch.randn(1, 3, 224, 224) for key in arch_settings.keys(): print('*' * 10, key) cfg.merge_from_file( 'configs/benchmarks/repvgg/repvgg_b2g4_cifar100_224_e100_sgd_calr.yaml' ) model = RepVGG(cfg) # print(model) outputs = model(data)[KEY_OUTPUT] assert outputs.shape == (1, 100) print('insert_regvgg_block -> fuse_regvgg_block') insert_repvgg_block(model) # print(model) model.eval() outputs_insert = model(data)[KEY_OUTPUT] fuse_repvgg_block(model) # print(model) model.eval() outputs_fuse = model(data)[KEY_OUTPUT] # print(outputs_insert) # print(outputs_fuse) print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2))) print(torch.allclose(outputs_insert, outputs_fuse, atol=1e-8)) assert torch.allclose(outputs_insert, outputs_fuse, atol=1e-8) print( 'insert_regvgg_block -> insert_acblock -> fuse_acblock -> fuse_regvgg_block' ) insert_repvgg_block(model) insert_acblock(model) # print(model) model.eval() outputs_insert = model(data)[KEY_OUTPUT] fuse_acblock(model) fuse_repvgg_block(model) # print(model) model.eval() outputs_fuse = model(data)[KEY_OUTPUT] print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2))) print(torch.allclose(outputs_insert, outputs_fuse, atol=1e-6)) assert torch.allclose(outputs_insert, outputs_fuse, atol=1e-6) print( 'insert_acblock -> insert_regvgg_block -> fuse_regvgg_block -> fuse_acblock' ) insert_repvgg_block(model) insert_acblock(model) # print(model) model.eval() outputs_insert = model(data)[KEY_OUTPUT] fuse_acblock(model) fuse_repvgg_block(model) # print(model) model.eval() outputs_fuse = model(data)[KEY_OUTPUT] print(torch.sqrt(torch.sum((outputs_insert - outputs_fuse)**2))) print(torch.allclose(outputs_insert, outputs_fuse, atol=1e-6)) assert torch.allclose(outputs_insert, outputs_fuse, atol=1e-6)