def test_ddp_model_precise_bn(): # test DDP model test_bigger_dataset = BiggerDataset() loader = DataLoader(test_bigger_dataset, batch_size=2) precise_bn_hook = PreciseBNHook(loader, num_iters=5) assert precise_bn_hook.num_iters == 5 assert precise_bn_hook.interval == 1 model = ExampleModel() model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=True) runner = EpochBasedRunner(model=model, batch_processor=None, optimizer=optimizer, logger=logger) runner.register_hook(precise_bn_hook) runner.run([loader], [('train', 1)], 1)
def test_precise_bn(): with pytest.raises(TypeError): # `data_loader` must be a Pytorch DataLoader test_dataset = ExampleModel() data_loader = DataLoader(test_dataset, batch_size=2, sampler=None, num_workers=0, shuffle=False) PreciseBNHook('data_loader') optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) test_dataset = ExampleDataset() loader = DataLoader(test_dataset, batch_size=2) model = ExampleModel() optimizer = build_optimizer(model, optimizer_cfg) data_loader = DataLoader(test_dataset, batch_size=2) precise_bn_loader = copy.deepcopy(data_loader) logger = get_logger('precise_bn') runner = EpochBasedRunner(model=model, batch_processor=None, optimizer=optimizer, logger=logger) with pytest.raises(AssertionError): # num_iters should be no larget than total # iters precise_bn_hook = PreciseBNHook(precise_bn_loader, num_iters=5) runner.register_hook(precise_bn_hook) runner.run([loader], [('train', 1)], 1) # test non-DDP model test_bigger_dataset = BiggerDataset() loader = DataLoader(test_bigger_dataset, batch_size=2) precise_bn_hook = PreciseBNHook(loader, num_iters=5) assert precise_bn_hook.num_iters == 5 assert precise_bn_hook.interval == 1 runner = EpochBasedRunner(model=model, batch_processor=None, optimizer=optimizer, logger=logger) runner.register_hook(precise_bn_hook) runner.run([loader], [('train', 1)], 1) # test model w/ gn layer loader = DataLoader(test_bigger_dataset, batch_size=2) precise_bn_hook = PreciseBNHook(loader, num_iters=5) assert precise_bn_hook.num_iters == 5 assert precise_bn_hook.interval == 1 model = GNExampleModel() runner = EpochBasedRunner(model=model, batch_processor=None, optimizer=optimizer, logger=logger) runner.register_hook(precise_bn_hook) runner.run([loader], [('train', 1)], 1) # test model without bn layer loader = DataLoader(test_bigger_dataset, batch_size=2) precise_bn_hook = PreciseBNHook(loader, num_iters=5) assert precise_bn_hook.num_iters == 5 assert precise_bn_hook.interval == 1 model = NoBNExampleModel() runner = EpochBasedRunner(model=model, batch_processor=None, optimizer=optimizer, logger=logger) runner.register_hook(precise_bn_hook) runner.run([loader], [('train', 1)], 1) # test how precise it is loader = DataLoader(test_bigger_dataset, batch_size=2) precise_bn_hook = PreciseBNHook(loader, num_iters=6) # run all assert precise_bn_hook.num_iters == 6 assert precise_bn_hook.interval == 1 model = SingleBNModel() runner = EpochBasedRunner(model=model, batch_processor=None, optimizer=optimizer, logger=logger) runner.register_hook(precise_bn_hook) runner.run([loader], [('train', 1)], 1) imgs_list = list() for _, data in enumerate(loader): imgs_list.append(np.array(data['imgs'])) mean = np.mean([np.mean(batch) for batch in imgs_list]) # bassel correction used in Pytorch, therefore ddof=1 var = np.mean([np.var(batch, ddof=1) for batch in imgs_list]) assert np.equal(mean, np.array(model.bn.running_mean)) assert np.equal(var, np.array(model.bn.running_var)) @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires CUDA support') def test_ddp_model_precise_bn(): # test DDP model test_bigger_dataset = BiggerDataset() loader = DataLoader(test_bigger_dataset, batch_size=2) precise_bn_hook = PreciseBNHook(loader, num_iters=5) assert precise_bn_hook.num_iters == 5 assert precise_bn_hook.interval == 1 model = ExampleModel() model = MMDistributedDataParallel( model.cuda(), device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=True) runner = EpochBasedRunner(model=model, batch_processor=None, optimizer=optimizer, logger=logger) runner.register_hook(precise_bn_hook) runner.run([loader], [('train', 1)], 1)