def test_single_run():
    """Train once with and once without activation checkpointing and assert
    that checkpointing does not increase peak CUDA memory usage.

    Fix: ``torch.cuda.memory_stats`` reports a *cumulative* peak, so without
    resetting it between iterations the second run's peak always includes the
    first run's allocations — the final assertion was trivially true. We now
    reset the peak tracker before each configuration so the two measurements
    are independent. Also corrected the format spec ``{:2f}`` (min-width 2)
    to the intended ``{:.2f}`` (two decimal places).
    """
    device, offload_device = _init()
    model = _get_model()
    peak_mem = {}
    for checkpoint_activation in [True, False]:
        # Reset the allocator's peak counter so each configuration is
        # measured on its own rather than inheriting the previous peak.
        torch.cuda.reset_peak_memory_stats(0)
        offload_model = OffloadModel(
            model=model,
            device=device,
            offload_device=offload_device,
            num_slices=2,
            checkpoint_activation=checkpoint_activation,
        )
        offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

        input = torch.ones(1000, 2).to(device)
        labels = torch.ones(1000, 2).to(device)
        offload_model.train()
        pred = offload_model(input)
        loss_fn = torch.nn.MSELoss(reduction="sum")
        loss = loss_fn(pred, labels)
        loss.backward()
        offload_optimizer.step()

        key = "ca_" + str(checkpoint_activation)
        peak_mem[key] = torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]
        print(
            "Peak allocated bytes on cuda:0 for checkpoint_activation "
            + str(checkpoint_activation)
            + ": {:.2f}".format(peak_mem[key])
        )

    # TODO(anj-s): We need a better requirement since this fails on CircleCI right now.
    assert peak_mem["ca_True"] <= peak_mem["ca_False"]
def get_model_and_optimizer(args, device, benchmark_config, model_specs):
    """Build the benchmark model (wrapped in OffloadModel) and pick an
    optimizer factory based on ``args.model_name``.

    Returns:
        A ``(model, optimizer)`` pair where ``optimizer`` is a callable that
        takes the model parameters and returns an optimizer instance.
    """
    if args.model_name == "lm":
        model = get_lm_model(args, device, model_specs)
        learning_rate = benchmark_config["lr"]

        # Bind the learning rate now so the factory matches the
        # ``optimizer(params)`` calling convention used by the caller.
        def make_adam(params):
            return Adam(params, lr=learning_rate)

        optimizer = make_adam
    elif args.model_name == "seq":
        model = get_seq_model(args, device, model_specs)
        optimizer = torch.optim.SGD

    # Wrap the model so slices are kept on CPU and paged onto the GPU
    # on demand during forward/backward.
    model = OffloadModel(
        model_cpu=model,
        device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        num_slices=benchmark_config["slices"],
        checkpoint_activation=benchmark_config["checkpoint_activation"],
        num_microbatches=benchmark_config["num_microbatches"],
    )
    return model, optimizer
def _train_offload_model(model, device, offload_device, use_fp16=False, checkpoint_activation=False, num_microbatches=1):
    """Deep-copy *model*, wrap it in an OffloadModel with two slices, and
    run the shared ``_train`` loop with an SGD optimizer (lr=0.001)."""
    # Copy so the caller's model is left untouched for comparison runs.
    wrapped = OffloadModel(
        model=copy.deepcopy(model),
        device=device,
        offload_device=offload_device,
        num_slices=2,
        checkpoint_activation=checkpoint_activation,
        num_microbatches=num_microbatches,
    )
    sgd = torch.optim.SGD(wrapped.parameters(), lr=0.001)
    return _train(wrapped, sgd, use_fp16, device)
def test_single_run():
    """Smoke test: one forward/backward/step pass through an OffloadModel."""
    device, offload_device = _init()
    offload_model = OffloadModel(
        model=_get_model(),
        device=device,
        offload_device=offload_device,
        num_slices=2,
    )
    offload_optimizer = torch.optim.SGD(offload_model.parameters(), lr=0.001)

    # Tiny fixed inputs/targets — we only check the pass completes.
    input = torch.ones(2, 2).to(device)
    labels = torch.ones(2, 2).to(device)

    offload_model.train()
    loss = torch.nn.MSELoss(reduction="sum")(offload_model(input), labels)
    loss.backward()
    offload_optimizer.step()