self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self._pos = 0

    def forward(self, x):
        """Run ``x`` through the next child module in the cycle, then apply
        the shared per-channel scale and shift.

        Successive calls use ``self[0]``, ``self[1]``, ... and wrap back to
        ``self[0]`` after the last entry, so each position in a repeating
        sequence of calls is handled by its own child module, while
        ``self.weight`` and ``self.bias`` are shared by all of them.

        Args:
            x: input forwarded unchanged to the current child module.

        Returns:
            ``child(x) * weight + bias``, with weight/bias broadcast along
            dim 1. The ``(1, -1, 1, 1)`` reshape assumes a 4-D child output,
            presumably ``(N, C, H, W)`` — confirm against callers.
        """
        current = self._pos
        normalized = self[current](x)
        # Advance only after the child call returned; wrap at the list length.
        self._pos = (current + 1) % len(self)

        channel_shape = (1, -1, 1, 1)
        scale = self.weight.reshape(channel_shape)
        shift = self.bias.reshape(channel_shape)
        return normalized * scale + shift


if __name__ == "__main__":
    # Build a RetinaNet whose head norm layers are CycleBatchNormList modules,
    # load a checkpoint, recompute BN statistics with PreciseBN on the
    # training loader, then evaluate on the test loader.
    # Usage: python <this script> CHECKPOINT
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} CHECKPOINT")
    checkpoint = sys.argv[1]

    cfg = LazyConfig.load_rel("./configs/retinanet_SyncBNhead.py")
    model_cfg = cfg.model
    # One BN per head input feature (presumably the head applies each norm
    # once per feature level, so cycling gives each level its own BN stats --
    # confirm against the head implementation). Capture the count now: a
    # closure over ``model`` would see whatever that name is bound to when
    # the factory is finally called, and ``model`` is rebound just below.
    num_bns = len(model_cfg.head_in_features)
    model_cfg.head.norm = lambda c: CycleBatchNormList(num_bns, c)
    model = instantiate(model_cfg)
    model.cuda()
    DetectionCheckpointer(model).load(checkpoint)

    # Batch size must be set before the train loader is instantiated below.
    cfg.dataloader.train.total_batch_size = 8
    # Iteration count passed to update_bn_stats for PreciseBN.
    precise_bn_iters = 500
    logger.info("Running PreciseBN ...")
    with EventStorage(), torch.no_grad():
        update_bn_stats(model, instantiate(cfg.dataloader.train),
                        precise_bn_iters)

    logger.info("Running evaluation ...")
    inference_on_dataset(model, instantiate(cfg.dataloader.test),
                         instantiate(cfg.dataloader.evaluator))
예제 #2
0
# Copyright (c) Facebook, Inc. and its affiliates.
from detectron2.config import LazyConfig

# Equivalent to a relative import: load_rel resolves "dir1_a.py" relative to
# this config file and returns the values of the two requested names.
dir1a_str, dir1a_dict = LazyConfig.load_rel("dir1_a.py",
                                            ("dir1a_str", "dir1a_dict"))

# Values derived from dir1_a; other configs can load these names by key,
# the same way the load_rel call above does.
dir1b_str = dir1a_str + "_from_b"
# Binds the same object as dir1a_dict (an alias, not a copy).
dir1b_dict = dir1a_dict

# Every import is a reload: not modified by other config files.
# Each load re-executes dir1_a.py, so changes other configs make to their own
# loaded copy of dir1a_dict must not leak in here. Presumes dir1_a.py sets
# dir1a_dict.a = 1 -- TODO confirm. Note: `assert` is skipped under python -O.
assert dir1a_dict.a == 1