def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    """Check that moving a ShardedDataParallel-wrapped module to another device is rejected.

    The model is deliberately built on the CPU, wrapped, and only then moved to ``device``
    through the wrapper. An ``AssertionError`` raised by ``ddp_model.to(device)`` is treated
    as the expected outcome; if no error is raised, the test fails.

    NOTE(review): a second ``run_test_device_change`` with a different signature (no
    ``reduce_buffer_size``) is defined later in this module and shadows this one at import
    time — confirm which definition is intended to be collected.
    """
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    try:
        # Not built on `device` on purpose: the device change is tested after the fact.
        model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(
            model, optimizer, sync_models_at_startup=False, reduce_buffer_size=reduce_buffer_size
        )

        try:
            ddp_model.to(device)
        except AssertionError:
            pass  # Expected: the device change was caught.
        else:
            # Raised from the `else` clause so the handler above cannot swallow it; an
            # `assert False` inside the `try` body would be caught by its own `except`.
            raise AssertionError("Changing devices should be caught and not supported")
    finally:
        # Always tear down the process group, even when the check above fails.
        dist.destroy_process_group()
def run_test_gpt2(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    """Run a few optimizer steps on a small GPT2 wrapped in ShardedDataParallel with OSS.

    Each step alternates ``zero_grad(set_to_none=...)`` between True and False, runs a
    forward/backward pass on random integer inputs, and then round-trips the wrapper through
    ``.to()`` with float16 and float32 dtypes. ``reduce_buffer_size`` is forwarded to
    ``ShardedDataParallel`` (the loop is commented as a check for bucketing overflows).
    """
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    try:
        # Per-rank seeds.
        torch.manual_seed(rank)
        np.random.seed(rank)

        model = GPT2(
            embed_dim=256,
            num_heads=2,
            num_layers=12,
            num_positions=INPUT_DIM * INPUT_DIM,
            num_vocab=512,
            num_classes=2,
        )
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

        # Move the model to another device post-construction
        model = model.to(device)

        # Optim loop
        set_to_none = True

        def closure():
            nonlocal set_to_none
            # Alternate the zero_grad mode on every call.
            ddp_model.zero_grad(set_to_none=set_to_none)
            set_to_none = not set_to_none

            # Force int inputs to prevent the first grad from firing
            input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
            return loss

        # Check for bucketing overflows
        for _ in range(STEPS):
            optimizer.step(closure=closure)

            # Stress test the .to() method
            ddp_model.to(device=device, dtype=torch.float16)
            ddp_model.to(device=device, dtype=torch.float32)
    finally:
        # Always tear down the process group, even if a step fails.
        dist.destroy_process_group()
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
    """Check that a ShardedDataParallel-wrapped module can be moved to another device.

    The model is built on the CPU, wrapped, and then moved to ``device`` through the wrapper.
    A forward pass with inputs created on ``device``, followed by a backward pass, exercises
    the moved module.

    NOTE(review): another ``run_test_device_change`` with a different signature (taking
    ``reduce_buffer_size``) is defined earlier in this module; this later definition shadows
    it at import time — confirm which one is intended.
    """
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    try:
        model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)
        ddp_model.to(device)

        inputs = torch.rand((10, 2), device=device)
        # Presumably raises (device mismatch) if the module was not actually moved — TODO confirm.
        outputs = ddp_model(inputs)
        # `backward()` returns None, so its result is deliberately not bound to a name.
        outputs.norm().backward()
    finally:
        # Always tear down the process group, even if the forward/backward pass fails.
        dist.destroy_process_group()