# Imports inferred for this excerpt; the full test module may pull in more.
import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS


def test_train_eval_change():
    # Check that ShardedDDP handles the switch from training to eval properly
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

    model = _get_mlp()
    model.train()
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)

    input_tensor = torch.rand((2, 2))
    loss = model(input_tensor).sum()
    loss.backward()  # make sure that the gradients are reduced

    # Wipe the gradients and switch to eval mode
    model.zero_grad()
    model.eval()
    _ = model(input_tensor)
    assert next(model.parameters()).grad is None or torch.norm(next(model.parameters()).grad) < 1e-6

    # Get back to training
    model = model.train()
    model(input_tensor).sum().backward()
    assert torch.norm(next(model.parameters()).grad) > 0.0

    dist.destroy_process_group()
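# _get_mlp() above is a helper defined elsewhere in this test module. A minimal
# sketch of what such a helper could look like, purely for illustration (the
# name and layer sizes here are assumptions, not the module's actual definition):
def _get_mlp_sketch():
    return Sequential(Linear(2, 3), torch.nn.ReLU(), Linear(3, 3))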
def run_test_training_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    group = dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    if "cuda" in str(device):
        torch.cuda.set_device(rank)  # only meaningful when running on GPUs

    model = Sequential(Linear(2, 3), Linear(3, 3)).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, process_group=group, reduce_buffer_size=reduce_buffer_size)

    inputs = torch.rand((10, 2), device=device)
    outputs = ddp_model(inputs)  # will assert internally if the module state has not been switched properly
    outputs.norm().backward()

    ddp_model.eval()
    ddp_model(inputs)  # this will assert if eval() is not properly taken into account
    ddp_model(inputs)  # a second eval forward must also pass: no gradient reduction is pending

    dist.destroy_process_group()
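# run_test_training_change() is written as a per-rank worker, so a test
# entrypoint would typically fan it out with torch.multiprocessing.spawn.
# A minimal sketch, with illustrative values (the world size, backend, device
# and reduce_buffer_size are assumptions, not the suite's actual parametrization):
def test_training_change_sketch():
    import torch.multiprocessing as mp

    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]
    # spawn() prepends the process rank to the args tuple for each worker
    mp.spawn(
        run_test_training_change,
        args=(world_size, "nccl", "cuda", temp_file_name, 2 ** 20),
        nprocs=world_size,
        join=True,
    )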
def run_eval_mode(_unused):
    """Check eval mode: this should not trip any asserts."""
    dist.init_process_group(init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1)
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer_params = {"lr": 0.1, "momentum": 0.99}
    # NOTE: the optimizer class and its params are passed directly here, an
    # older ShardedDataParallel calling convention than the one used above
    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
    optimizer = ddp.optimizer

    ddp.eval()
    for _ in range(5):
        input_tensor = torch.rand((64, 2))
        _ = ddp(input_tensor)

    ddp.train()
    try:
        for _ in range(5):
            input_tensor = torch.rand((64, 2))
            _ = ddp(input_tensor)
    except RuntimeError:
        pass
    else:
        assert False, "Several forward passes in training mode, with no backward in between, should raise"

    dist.destroy_process_group()
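# run_eval_mode() takes a single throwaway argument, which matches the calling
# convention of torch.multiprocessing.spawn (the process index is passed in but
# unused, since this test is single-process). A minimal sketch of a driver,
# assuming that convention:
def test_eval_mode_sketch():
    import torch.multiprocessing as mp

    mp.spawn(run_eval_mode, args=(), nprocs=1, join=True)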