Example #1
def parity3d_bn():
    """Single-GPU parity check: torch.nn.BatchNorm3d vs. SyncBatchNorm on 5D input."""
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)

    x = torch.randn(4, 3, 4, 4, 4).cuda()
    torch_bn = torch.nn.BatchNorm3d(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity(torch_bn, fs_bn, x)
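The check_parity helper is not reproduced in this listing. Below is a minimal sketch of what such a helper could look like, assuming it only needs to run both modules on the same data and compare outputs and input gradients (the real helper may also check parameters and running statistics):

import torch


def check_parity(torch_bn, fs_bn, x):
    # Hypothetical sketch: run both batch-norm modules on clones of the same
    # input and verify that outputs and input gradients agree.
    x1 = x.detach().clone().requires_grad_(True)
    x2 = x.detach().clone().requires_grad_(True)

    y1 = torch_bn(x1)
    y2 = fs_bn(x2)
    torch.testing.assert_close(y1, y2)

    y1.sum().backward()
    y2.sum().backward()
    torch.testing.assert_close(x1.grad, x2.grad)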
Example #2
def parity2d_syncbn():
    """DDP parity check: torch.nn.SyncBatchNorm vs. SyncBatchNorm on rank-dependent 4D input."""
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)

    x = torch.randn(4, 3, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity_ddp(torch_bn, fs_bn, x)
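check_parity_ddp is likewise not shown here. A rough sketch, assuming it wraps both modules in DistributedDataParallel so that the rank-dependent inputs exercise the cross-rank statistics, and then compares outputs and input gradients:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def check_parity_ddp(torch_bn, fs_bn, x):
    # Hypothetical sketch: wrap both modules in DDP, run the same rank-local
    # input through each, and compare forward outputs and input gradients.
    rank = dist.get_rank()
    torch_ddp = DDP(torch_bn, device_ids=[rank])
    fs_ddp = DDP(fs_bn, device_ids=[rank])

    x1 = x.detach().clone().requires_grad_(True)
    x2 = x.detach().clone().requires_grad_(True)

    y1 = torch_ddp(x1)
    y2 = fs_ddp(x2)
    torch.testing.assert_close(y1, y2)

    y1.sum().backward()
    y2.sum().backward()
    torch.testing.assert_close(x1.grad, x2.grad)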
Example #3
def parity3d_checkpoint_syncbn():
    """DDP parity check with the custom SyncBatchNorm wrapped in activation checkpointing."""
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)

    x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    fs_bn = checkpoint_wrapper(fs_bn, maintain_forward_counter=True)
    check_parity_ddp(torch_bn, fs_bn, x)
Example #4
def parity3d_checkpoint_syncbn_twice():
    """DDP parity check with the same SyncBatchNorm applied twice inside a checkpointed nn.Sequential."""
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)

    x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3)
    torch_bn = nn.Sequential(torch_bn, torch_bn).cuda()
    fs_bn = SyncBatchNorm(3)
    fs_bn = nn.Sequential(fs_bn, fs_bn).cuda()
    fs_bn = checkpoint_wrapper(fs_bn)
    check_parity_ddp(torch_bn, fs_bn, x)
Example #5
def memory_allocated():
    """Verify the custom SyncBatchNorm's forward pass allocates no more memory than torch.nn.SyncBatchNorm (within ~1%)."""
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    x = torch.randn(50, 2048, 7, 7).to(rank)
    torch_bn = torch.nn.SyncBatchNorm(2048).cuda()
    torch_bn = DDP(torch_bn, device_ids=[rank])
    fs_bn = SyncBatchNorm(2048).cuda()
    fs_bn = DDP(fs_bn, device_ids=[rank])
    torch_x = x.detach()
    torch_x.requires_grad = True
    fs_x = x.detach()
    fs_x.requires_grad = True
    torch.cuda.empty_cache()
    mem_at_start = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    torch_y = torch_bn(torch_x)
    torch.cuda.empty_cache()
    mem_after_torch = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    fs_y = fs_bn(fs_x)
    torch.cuda.empty_cache()
    mem_final = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    torch_used = mem_after_torch - mem_at_start
    fs_used = mem_final - mem_after_torch
    # The custom SyncBatchNorm should allocate at most ~1% more than the upstream implementation.
    assert fs_used < (torch_used * 1.01), f"{fs_used} < {torch_used * 1.01}"
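All of the functions above assume one GPU per rank, an already-initialized process group, and the project-specific imports (SyncBatchNorm, checkpoint_wrapper, check_parity, check_parity_ddp, DDP). A minimal, hypothetical launcher for driving one of them with torch.multiprocessing.spawn and the NCCL backend might look like the sketch below; _worker and the master address/port are placeholders, not part of the original tests:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size, fn):
    # Hypothetical harness: one process per GPU, NCCL backend.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    fn()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size, parity2d_syncbn), nprocs=world_size)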