import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

# NOTE: the import paths for the non-torch pieces below are assumed from the
# usual fairscale layout; adjust them to wherever SyncBatchNorm and
# checkpoint_wrapper live in your checkout.
from fairscale.experimental.nn import SyncBatchNorm
from fairscale.nn.checkpoint import checkpoint_wrapper


def parity3d_bn():
    # Single-process parity: SyncBatchNorm should match torch.nn.BatchNorm3d.
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4, 4).cuda()
    torch_bn = torch.nn.BatchNorm3d(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity(torch_bn, fs_bn, x)
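
# check_parity is a helper defined elsewhere in this test module. A minimal
# sketch of what such a single-process parity check might do (the name
# check_parity_sketch, the tolerances, and the exact comparisons are
# assumptions, not the project's actual implementation):
def check_parity_sketch(torch_bn, fs_bn, x):
    # Start both modules from identical parameters and running statistics.
    fs_bn.load_state_dict(torch_bn.state_dict())
    torch_x = x.detach().requires_grad_()
    fs_x = x.detach().requires_grad_()
    torch_y = torch_bn(torch_x)
    fs_y = fs_bn(fs_x)
    assert torch.allclose(torch_y, fs_y, atol=1e-6)
    # Gradients w.r.t. the input should match as well.
    torch_y.sum().backward()
    fs_y.sum().backward()
    assert torch.allclose(torch_x.grad, fs_x.grad, atol=1e-6)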
def parity2d_syncbn():
    # Multi-process parity against torch.nn.SyncBatchNorm; scaling the input
    # by rank gives each process different data, so the cross-rank statistics
    # actually matter.
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    check_parity_ddp(torch_bn, fs_bn, x)
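
# check_parity_ddp is likewise defined elsewhere in the module. A hypothetical
# sketch, assuming it wraps both modules in DDP, trains a few steps so running
# stats and gradient reduction are exercised, then compares parameters:
def check_parity_ddp_sketch(torch_bn, fs_bn, x):
    rank = dist.get_rank()
    torch_ddp = DDP(torch_bn, device_ids=[rank])
    fs_ddp = DDP(fs_bn, device_ids=[rank])
    torch_optim = torch.optim.SGD(torch_ddp.parameters(), lr=0.1)
    fs_optim = torch.optim.SGD(fs_ddp.parameters(), lr=0.1)
    for _ in range(2):
        for model, optim in ((torch_ddp, torch_optim), (fs_ddp, fs_optim)):
            optim.zero_grad()
            model(x).sum().backward()
            optim.step()
    for torch_p, fs_p in zip(torch_ddp.parameters(), fs_ddp.parameters()):
        assert torch.allclose(torch_p, fs_p, atol=1e-5)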
def parity3d_checkpoint_syncbn():
    # Same parity check, but with the fairscale module behind activation
    # checkpointing; maintain_forward_counter lets the wrapped module tell
    # real forward passes apart from checkpoint recomputations, so running
    # stats are not updated twice.
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3).cuda()
    fs_bn = SyncBatchNorm(3).cuda()
    fs_bn = checkpoint_wrapper(fs_bn, maintain_forward_counter=True)
    check_parity_ddp(torch_bn, fs_bn, x)
def parity3d_checkpoint_syncbn_twice():
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    x = torch.randn(4, 3, 4, 4, 4).cuda() * rank
    torch_bn = torch.nn.SyncBatchNorm(3)
    torch_bn = nn.Sequential(torch_bn, torch_bn).cuda()
    fs_bn = SyncBatchNorm(3)
    fs_bn = nn.Sequential(fs_bn, fs_bn).cuda()
    fs_bn = checkpoint_wrapper(fs_bn)
    check_parity_ddp(torch_bn, fs_bn, x)
def memory_allocated():
    # Measure how much memory a forward pass leaves allocated for torch's
    # SyncBatchNorm vs. the fairscale implementation; the latter should not
    # use more (1% slack for allocator noise).
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    x = torch.randn(50, 2048, 7, 7).to(rank)
    torch_bn = torch.nn.SyncBatchNorm(2048).cuda()
    torch_bn = DDP(torch_bn, device_ids=[rank])
    fs_bn = SyncBatchNorm(2048).cuda()
    fs_bn = DDP(fs_bn, device_ids=[rank])

    torch_x = x.detach()
    torch_x.requires_grad = True
    fs_x = x.detach()
    fs_x.requires_grad = True

    # Keep torch_y and fs_y alive so the outputs and saved activations stay
    # allocated while we read the counters.
    torch.cuda.empty_cache()
    mem_at_start = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    torch_y = torch_bn(torch_x)
    torch.cuda.empty_cache()
    mem_after_torch = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    fs_y = fs_bn(fs_x)
    torch.cuda.empty_cache()
    mem_final = torch.cuda.memory_stats()["allocated_bytes.all.current"]

    torch_used = mem_after_torch - mem_at_start
    fs_used = mem_final - mem_after_torch
    assert fs_used < (torch_used * 1.01), f"expected {fs_used} < {torch_used * 1.01}"
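
# These tests assume an already-initialized process group and one GPU per
# rank. A typical launcher (assumed here, not part of the original file)
# would spawn one process per GPU and dispatch to a test function:
import os
import torch.multiprocessing as mp

def _worker(rank, world_size, test_fn):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    test_fn()
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(_worker, args=(world_size, parity2d_syncbn), nprocs=world_size)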