def test_memory_tracking_nlp_model():
    """
    Check that we can collect memory traces of a realistic model
    outside of the context of distributed training (DDP or FSDP).

    Builds a small GPT2 on the current CUDA device and attaches a
    ``LayerwiseMemoryTracker``. It then runs one forward/backward pass and
    checks that memory, forward and backward traces were all recorded. It
    also checks that the summarized activation allocations match a pinned
    regression value.
    """
    BATCH_SIZE = 10
    INPUT_DIM = 16
    # Regression value pinned for exactly this model configuration and input
    # shape. It must be updated if GPT2's architecture or the sizes above
    # change. (Unit not visible here; presumably bytes — confirm against
    # LayerwiseMemoryTracker.)
    EXPECTED_ACTIVATION_ALLOCATIONS = 12462080

    model = GPT2(
        embed_dim=256,
        num_heads=2,
        num_layers=6,
        num_positions=INPUT_DIM * INPUT_DIM,
        num_vocab=512,
        num_classes=2,
    ).cuda()
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)

    # Integer token ids in [0, 10); a single forward + backward pass while the
    # tracker is monitoring the model.
    input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).cuda()
    output = model(input_tensor)
    output.sum().backward()

    assert len(tracker.memory_traces) > 0, "failed to collect memory traces"
    assert len(tracker.forward_traces) > 0, "failed to collect forward memory traces"
    assert len(tracker.backward_traces) > 0, "failed to collect backward memory traces"
    total = tracker.summary.total_activation_allocations
    assert total == EXPECTED_ACTIVATION_ALLOCATIONS, (
        f"unexpected total_activation_allocations: {total} "
        f"(expected {EXPECTED_ACTIVATION_ALLOCATIONS})"
    )
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    """Train a GPT2 for a few steps with an OSS optimizer and ShardedDataParallel.

    Per-rank worker: it joins a process group, builds the model and runs
    ``STEPS`` optimizer steps with a closure. This exercises gradient
    reduction on a fairly large model, to catch bucketing overflows.

    NOTE(review): this SOURCE also contains a ``run_test_gpt2`` with an extra
    ``reduce_buffer_size`` parameter. If both live in the same module, the
    later definition shadows this one — confirm they come from separate files.

    Args:
        rank: rank of this process in the group; also used as the CUDA device
            index and as the RNG seed.
        world_size: total number of processes in the group.
        backend: ``torch.distributed`` backend name passed to
            ``init_process_group``.
        device: device the model and inputs are placed on.
        temp_file_name: path of a file used as the ``file://`` rendezvous
            for ``init_process_group``.
    """
    INPUT_DIM = 32
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    # Always tear the process group down, even if a step raises, so a failing
    # rank does not leave the group initialized.
    try:
        # Compare on device *type* so "cuda", "cuda:N" and torch.device
        # objects are all handled.
        if torch.device(device).type == "cuda":
            torch.cuda.set_device(rank)
        torch.manual_seed(rank)
        np.random.seed(rank)

        model = GPT2(
            embed_dim=512,
            num_heads=2,
            num_layers=24,
            num_positions=INPUT_DIM * INPUT_DIM,
            num_vocab=512,
            num_classes=2,
        ).to(device)
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        def closure():
            """Zero grads, run forward/backward on a random batch, return the loss."""
            optimizer.zero_grad()
            # Force int inputs to prevent the first grad from firing
            # (integer tensors cannot require grad).
            input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
            return loss

        # Check for bucketing overflows
        for _ in range(STEPS):
            optimizer.step(closure=closure)
    finally:
        dist.destroy_process_group()
def run_test_gpt2(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
    """Train a GPT2 with OSS + ShardedDataParallel, then stress-test ``.to()``.

    Per-rank worker steps:

    * Join a process group.
    * Wrap a GPT2 in ShardedDataParallel *before* moving it to ``device``, to
      exercise a post-construction device move.
    * Run ``STEPS`` optimizer steps. The closure alternates
      ``zero_grad(set_to_none=True)`` and ``set_to_none=False`` on each call.
    * Cast the wrapped model to float16 and back to float32.

    Args:
        rank: rank of this process in the group; also used as the CUDA device
            index and as the RNG seed.
        world_size: total number of processes in the group.
        backend: ``torch.distributed`` backend name passed to
            ``init_process_group``.
        device: device the model and inputs are placed on.
        temp_file_name: path of a file used as the ``file://`` rendezvous
            for ``init_process_group``.
        reduce_buffer_size: forwarded unchanged to ``ShardedDataParallel``.
    """
    INPUT_DIM = 16
    BATCH_SIZE = 10
    STEPS = 10

    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    # Always tear the process group down, even if a step raises.
    try:
        # Only pin a CUDA device when actually running on CUDA. Calling this
        # unconditionally fails on machines without a GPU.
        if torch.device(device).type == "cuda":
            torch.cuda.set_device(rank)
        torch.manual_seed(rank)
        np.random.seed(rank)

        model = GPT2(
            embed_dim=256,
            num_heads=2,
            num_layers=12,
            num_positions=INPUT_DIM * INPUT_DIM,
            num_vocab=512,
            num_classes=2,
        )
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

        # Move the model to another device post-construction
        # (nn.Module.to moves parameters in place and returns the same module).
        model.to(device)

        # Alternate between set_to_none=True/False across closure calls so both
        # zero_grad code paths are exercised.
        set_to_none = True

        def closure():
            """Zero grads (alternating mode), run forward/backward, return the loss."""
            nonlocal set_to_none
            ddp_model.zero_grad(set_to_none=set_to_none)
            set_to_none = not set_to_none

            # Force int inputs to prevent the first grad from firing
            # (integer tensors cannot require grad).
            input_tensor = torch.randint(10, (BATCH_SIZE, INPUT_DIM)).to(device)
            loss = ddp_model(input_tensor).abs().sum()
            loss.backward()
            return loss

        # Check for bucketing overflows
        for _ in range(STEPS):
            optimizer.step(closure=closure)

        # Stress test the .to() method
        ddp_model.to(device=device, dtype=torch.float16)
        ddp_model.to(device=device, dtype=torch.float32)
    finally:
        dist.destroy_process_group()