def ddp_ref():
    """Produce reference DDP training results for comparison against FSDP.

    Runs DDP training twice (full precision and mixed precision) across
    ``_world_size`` ranks for ``_iterations`` steps each, starting from a
    single shared initial model state and fixed per-rank random inputs.

    Returns:
        tuple: ``(state_before, inputs, state_after_fp, state_after_mp)``
            - state_before: the initial ``Model`` state_dict shared by all runs.
            - inputs: per-rank list of per-iteration input tensors.
            - state_after_fp: rank-0 state_dict after the full-precision run.
            - state_after_mp: rank-0 state_dict after the mixed-precision run.
    """
    import os  # local import: needed only to close mkstemp's file descriptors

    # Get a reference model state.
    model = Model()
    state_before = model.state_dict()

    # Get reference inputs per rank.
    world_size = _world_size
    iterations = _iterations
    print(
        f"Getting DDP reference for world_size {world_size} and iterations {iterations}"
    )
    inputs = [[] for _ in range(world_size)]
    for rank in range(world_size):
        for _ in range(iterations):
            inputs[rank].append(torch.rand(2, 2, 2, 2))

    # Run DDP training twice, fp and mp.
    for precision in ["full", "mixed"]:
        # mkstemp returns an (fd, path) pair; close each fd immediately —
        # discarding it (the original `mkstemp()[1]`) leaks a descriptor
        # per temp file, per loop iteration.
        fd, temp_file_name = tempfile.mkstemp()
        os.close(fd)
        fd, unused = tempfile.mkstemp()
        os.close(fd)
        fd, rank_0_output = tempfile.mkstemp()
        os.close(fd)
        try:
            fsdp_config = None  # This means we use DDP in _test_func.
            mp.spawn(
                _test_func,
                args=(
                    world_size,
                    fsdp_config,
                    None,
                    precision == "mixed",
                    temp_file_name,
                    unused,
                    state_before,
                    inputs,
                    rank_0_output,
                    None,
                ),
                nprocs=world_size,
                join=True,
            )
            if precision == "full":
                state_after_fp = torch.load(rank_0_output)
            else:
                state_after_mp = torch.load(rank_0_output)
        finally:
            rmf(temp_file_name)
            rmf(unused)
            rmf(rank_0_output)

    # Sanity check: the fp and mp runs must have diverged, otherwise mixed
    # precision had no effect and the reference is not meaningful.
    assert state_dict_norm(state_after_fp) != state_dict_norm(state_after_mp)

    return state_before, inputs, state_after_fp, state_after_mp
def ddp_ref():
    """Produce reference DDP training results for comparison against FSDP.

    Randomly samples conv/linear bias flavors (instead of parameterizing them,
    to keep test runtime down), then runs DDP training four times — the cross
    product of precision in {full, mixed} and sync_bn in {none, pytorch} —
    across ``_world_size`` ranks for ``_iterations`` steps each, starting from
    a single shared initial model state and fixed per-rank random inputs.

    Returns:
        tuple: ``(state_before, inputs, conv_bias, linear_bias, state_after)``
            - state_before: the initial ``Model`` state_dict shared by all runs.
            - inputs: per-rank list of per-iteration input tensors.
            - conv_bias / linear_bias: the sampled bias flags used by ``Model``.
            - state_after: dict mapping ``(precision, sync_bn)`` to the rank-0
              state_dict after that run.
    """
    import os  # local import: needed only to close mkstemp's file descriptors

    # Cover different bias flavors. Use random instead of parameterize them to
    # reduce the test runtime. Otherwise, we would have covered all cases
    # exhaustively.
    conv_bias = random.randint(0, 1) != 0
    linear_bias = random.randint(0, 1) != 0

    # Get a reference model state.
    model = Model(conv_bias, linear_bias)
    state_before = model.state_dict()

    # Get reference inputs per rank.
    world_size = _world_size
    iterations = _iterations
    print(
        f"Getting DDP reference for world_size {world_size} and iterations {iterations}"
    )
    inputs = [[] for _ in range(world_size)]
    for rank in range(world_size):
        for _ in range(iterations):
            inputs[rank].append(torch.rand(2, 2, 2, 2))

    # Run reference DDP training 4 times, fp and mp, sync_bn or not.
    state_after = {}
    for precision, sync_bn in product(["full", "mixed"], ["none", "pytorch"]):
        # mkstemp returns an (fd, path) pair; close each fd immediately —
        # discarding it (the original `mkstemp()[1]`) leaks a descriptor
        # per temp file, per loop iteration.
        fd, temp_file_name = tempfile.mkstemp()
        os.close(fd)
        fd, unused = tempfile.mkstemp()
        os.close(fd)
        fd, rank_0_output = tempfile.mkstemp()
        os.close(fd)
        try:
            fsdp_config = None  # This means we use DDP in _distributed_worker.
            mp.spawn(
                _distributed_worker,
                args=(
                    world_size,
                    fsdp_config,
                    None,
                    precision == "mixed",
                    temp_file_name,
                    unused,
                    state_before,
                    inputs,
                    rank_0_output,
                    None,
                    sync_bn,
                    conv_bias,
                    linear_bias,
                ),
                nprocs=world_size,
                join=True,
            )
            state_after[(precision, sync_bn)] = torch.load(rank_0_output)
        finally:
            rmf(temp_file_name)
            rmf(unused)
            rmf(rank_0_output)

    # Sanity check DDP's final states: every configuration must have produced
    # weights distinct from the first one, otherwise the sweep is degenerate.
    states = list(state_after.values())
    for state in states[1:]:
        assert state_dict_norm(states[0]) != state_dict_norm(state)

    return state_before, inputs, conv_bias, linear_bias, state_after