Example #1
import os
import re

import numpy as np
import torch
import torch.distributed as dist

# import paths below assume mmcv 1.x; adjust for your installed version
from mmcv.cnn import ConvModule
from mmcv.cnn.utils import revert_sync_batchnorm


def test_revert_mmsyncbn():
    if 'SLURM_NTASKS' not in os.environ or int(os.environ['SLURM_NTASKS']) < 2:
        print('Must run on slurm with more than 1 process!\n'
              'srun -p test --gres=gpu:2 -n2')
        return
    rank = int(os.environ['SLURM_PROCID'])
    world_size = int(os.environ['SLURM_NTASKS'])
    local_rank = int(os.environ['SLURM_LOCALID'])
    node_list = str(os.environ['SLURM_NODELIST'])

    # derive the master node's IP from the SLURM node list; this parsing
    # assumes a cluster naming scheme such as 'xx-10-5-30-[96,98]'
    node_parts = re.findall('[0-9]+', node_list)
    os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' +
                                 f'.{node_parts[3]}.{node_parts[4]}')
    os.environ['MASTER_PORT'] = '12341'
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)

    dist.init_process_group('nccl')
    torch.cuda.set_device(local_rank)
    # broadcast the input from rank 0 so every process sees identical data
    x = torch.randn(1, 3, 10, 10).cuda()
    dist.broadcast(x, src=0)
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda()
    conv.eval()
    y_mmsyncbn = conv(x).detach().cpu().numpy()
    # reverting MMSyncBN to a plain BN layer should leave outputs unchanged
    conv = revert_sync_batchnorm(conv)
    y_bn = conv(x).detach().cpu().numpy()
    assert np.all(np.isclose(y_bn, y_mmsyncbn, 1e-3))
    # the reverted model should also run on CPU and produce matching outputs
    conv, x = conv.to('cpu'), x.to('cpu')
    y_bn_cpu = conv(x).detach().numpy()
    assert np.all(np.isclose(y_bn, y_bn_cpu, 1e-3))
Example #2
import pytest
import torch

# import paths below assume mmcv 1.x; adjust for your installed version
from mmcv.cnn import ConvModule
from mmcv.cnn.utils import revert_sync_batchnorm


def test_revert_sync_batchnorm():
    conv_syncbn = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')).to('cpu')
    conv_syncbn.train()
    x = torch.randn(1, 3, 10, 10)
    # SyncBN raises a ValueError because it cannot run on CPU
    with pytest.raises(ValueError):
        y = conv_syncbn(x)
    conv_bn = revert_sync_batchnorm(conv_syncbn)
    y = conv_bn(x)
    assert y.shape == (1, 8, 9, 9)
    assert conv_bn.training == conv_syncbn.training
    conv_syncbn.eval()
    conv_bn = revert_sync_batchnorm(conv_syncbn)
    assert conv_bn.training == conv_syncbn.training
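
For context, here is a minimal sketch of what a revert_sync_batchnorm-style
utility does: recursively replace torch.nn.SyncBatchNorm layers with
torch.nn.BatchNorm2d, carrying over the affine parameters and running
statistics. This illustrates the idea only and is not mmcv's actual
implementation (which additionally handles its own MMSyncBN type); the name
revert_sync_batchnorm_sketch is hypothetical.

import torch
import torch.nn as nn


def revert_sync_batchnorm_sketch(module: nn.Module) -> nn.Module:
    # Sketch only: swap SyncBatchNorm back to BatchNorm2d, keeping
    # parameters, running statistics and the training flag intact.
    converted = module
    if isinstance(module, nn.SyncBatchNorm):
        converted = nn.BatchNorm2d(module.num_features, module.eps,
                                   module.momentum, module.affine,
                                   module.track_running_stats)
        if module.affine:
            with torch.no_grad():
                converted.weight = module.weight
                converted.bias = module.bias
        converted.running_mean = module.running_mean
        converted.running_var = module.running_var
        converted.num_batches_tracked = module.num_batches_tracked
        converted.training = module.training
    # recurse so nested SyncBatchNorm layers are replaced as well
    for name, child in module.named_children():
        converted.add_module(name, revert_sync_batchnorm_sketch(child))
    return converted

A model converted for distributed training with
torch.nn.SyncBatchNorm.convert_sync_batchnorm could be passed through such a
function before CPU inference, which is exactly the scenario both tests above
exercise.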