def test_per_sample_grads_embeddingnet(self, device):
    class SampleNet(nn.Module):
        def __init__(self, vocab_size: int):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, 16)
            self.fc1 = nn.Linear(16, 16)
            self.fc2 = nn.Linear(16, 2)

        def forward(self, x):
            x = self.emb(x)
            x = torch.transpose(x, -1, -2)
            x = torch.mean(x, -1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

        def name(self):
            return "SampleNet"

    # Create our inputs...
    vocab_size = 1000
    batch_shape = [64]
    words_per_sentence = 5
    data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
    targets = torch.randint(0, 1, (*batch_shape,), device=device)

    # Construct our module
    net = SampleNet(vocab_size).to(device=device)
    criterion = nn.CrossEntropyLoss()

    params = dict(net.named_parameters())
    weights, net_func, _ = make_functional(net)

    def compute_loss(weights, data, target):
        output = net_func(weights, (data,))
        result = criterion(output, target)
        return result

    # Reference: per-sample grads via an explicit loop over the batch.
    expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
    expected = zip(*expected)
    expected = tuple(torch.stack(shards) for shards in expected)

    # Same computation in one shot: vmap over (data, targets) with `weights` fixed.
    result = vmap(partial(grad(compute_loss), weights))(data, targets)
    for r, e in zip(result, expected):
        # TODO: Check if the rtol is a problem
        self.assertEqual(r, e, atol=0, rtol=1e-4)
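# Illustrative note (not part of the test above, and assuming `compute_loss`,
# `weights`, `data`, and `targets` from it are in scope): the `partial`-based
# call is equivalent to keeping `weights` as an explicit argument and marking
# it as unbatched via `in_dims`.
per_sample_grad_fn = vmap(grad(compute_loss), in_dims=(None, 0, 0))
per_sample_grads = per_sample_grad_fn(weights, data, targets)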
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        return (self.mod(x)**2).sum()

batch_size = 16
features = 64
num_layers = 8
inp = torch.randn((batch_size, features))

mod = Foo(num_layers, features)
jit_mod = torch.jit.script(mod)

func_model, weights = make_functional(mod)
lr = 1.0

def functional_step(x, weights):
    weights = [weight.detach().requires_grad_() for weight in weights]
    out = func_model(weights, x)
    out.backward()
    new_weights = [weight - lr * weight.grad for weight in weights]
    return out, new_weights

optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0,
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.log_softmax(x, -1)
        return x

loss_fn = nn.NLLLoss()

# Step 3: Make the model functional(!!) and define a training function.
# NB: this mechanism doesn't exist in PyTorch today, but we want it to:
# https://github.com/pytorch/pytorch/issues/49171
func_model, weights = make_functional(MLPClassifier().to(DEVICE))

def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = func_model(weights, batch)
        loss = loss_fn(output, targets)
        return loss

    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    new_weights = []
    with torch.no_grad():
        for grad_weight, weight in zip(grad_weights, weights):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

# TODO: Use F.mse_loss
def mse_loss(x, y):
    return torch.mean((x - y)**2)

net, params = make_functional(ThreeLayerNet())
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

K = 20
losses = []
num_tasks = 4

def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

# The prototype doesn't like F.mse_loss.
def mse_loss(x, y):
    return torch.mean((x - y)**2)

params, net, _ = make_functional(ThreeLayerNet())
opt = torch.optim.Adam(params, lr=1e-3)
alpha = 0.1

K = 20
losses = []
num_tasks = 4

def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
def test_resnet18_per_sample_grads(self, device):
    # Straight out of opacus
    def _replace_child(
        root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
    ) -> None:
        # find the immediate parent
        parent = root
        nameList = child_name.split(".")
        for name in nameList[:-1]:
            parent = parent._modules[name]
        # replace the named child with the converted module
        parent._modules[nameList[-1]] = converter(parent._modules[nameList[-1]])

    def replace_all_modules(
        root: nn.Module,
        target_class: Type[nn.Module],
        converter: Callable[[nn.Module], nn.Module],
    ) -> nn.Module:
        # base case
        if isinstance(root, target_class):
            return converter(root)

        for name, obj in root.named_modules():
            if isinstance(obj, target_class):
                _replace_child(root, name, converter)
        return root

    def _batchnorm_to_groupnorm(module: nn.modules.batchnorm._BatchNorm) -> nn.Module:
        return nn.GroupNorm(min(32, module.num_features), module.num_features, affine=True)

    def convert_batchnorm_modules(
        model: nn.Module,
        converter: Callable[
            [nn.modules.batchnorm._BatchNorm], nn.Module
        ] = _batchnorm_to_groupnorm,
    ) -> nn.Module:
        return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, converter)

    import torchvision.models as models

    # BatchNorm mixes statistics across the batch, so per-sample gradients are
    # not well-defined for it; swap every BatchNorm layer for GroupNorm first.
    model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
    criterion = nn.CrossEntropyLoss()

    weights, func_model, descriptors = make_functional(model)

    def compute_loss(weights, image, target):
        images = image.unsqueeze(0)
        targets = target.unsqueeze(0)
        output = func_model(weights, (images,))
        loss = criterion(output, targets)
        return loss

    batch_size = 3
    images = torch.randn(batch_size, 3, 32, 32, device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)

    # Per-sample grads in one shot: vmap over the batch, keeping `weights` unbatched.
    result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets)

    # Reference: compute each sample's grads with a plain autograd call, then stack.
    expected_grads = [
        torch.autograd.grad(compute_loss(weights, images[i], targets[i]), weights)
        for i in range(batch_size)
    ]
    expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]

    self.assertEqual(result_grads, expected_grads)
def init_fn(num_models):
    models = tuple(MLPClassifier().to(device) for _ in range(num_models))
    weights = tuple(make_functional(model)[0] for model in models)
    weights = tuple(zip(*weights))
    weights = tuple(torch.stack(shards).detach() for shards in weights)
    return weights
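# Usage sketch (assuming the `init_fn` above is in scope): the returned tuple
# holds one stacked tensor per parameter with the ensemble dimension first, so
# indexing along dim 0 recovers the parameters of a single ensemble member.
stacked_weights = init_fn(num_models=4)
weights_of_model_0 = tuple(w[0] for w in stacked_weights)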
def test_ensemble_regression(self, device):
    def make_spirals(n_samples, noise_std=0., rotations=1.):
        ts = torch.linspace(0, 1, n_samples)
        rs = ts ** 0.5
        thetas = rs * rotations * 2 * math.pi
        signs = torch.randint(0, 2, (n_samples,)) * 2 - 1
        labels = (signs > 0).to(torch.long)

        xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std
        ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std
        points = torch.stack([xs, ys], dim=1)
        return points.to(device), labels.to(device)

    points, labels = make_spirals(100, noise_std=0.05)

    class MLPClassifier(nn.Module):
        def __init__(self, hidden_dim=32, n_classes=2):
            super().__init__()
            self.hidden_dim = hidden_dim
            self.n_classes = n_classes

            self.fc1 = nn.Linear(2, self.hidden_dim)
            self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

        def forward(self, x):
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            x = F.log_softmax(x, -1)
            return x

    loss_fn = nn.NLLLoss()

    weights, func_model, _ = make_functional(MLPClassifier().to(device))

    def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
        def compute_loss(weights, batch, targets):
            output = func_model(weights, (batch,))
            loss = loss_fn(output, targets)
            return loss

        if use_transform:
            grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
        else:
            loss = compute_loss(weights, batch, targets)
            grad_weights = torch.autograd.grad(loss, weights)

        new_weights = []
        with torch.no_grad():
            for grad_weight, weight in zip(grad_weights, weights):
                new_weights.append(weight - grad_weight * lr)

        # NB: return looks weird because torch.vmap must return Tensors
        return (loss, *new_weights)

    def unpack(train_result):
        return train_result[0], train_result[1:]

    def init_fn(num_models):
        models = tuple(MLPClassifier().to(device) for _ in range(num_models))
        weights = tuple(make_functional(model)[0] for model in models)
        weights = tuple(zip(*weights))
        weights = tuple(torch.stack(shards).detach() for shards in weights)
        return weights

    def slice_weights(batched_weights, index):
        return tuple(weight[index].detach().requires_grad_() for weight in batched_weights)

    batched_weights = init_fn(num_models=2)

    # Train the whole ensemble in one vmapped call; the weights are batched along
    # dim 0, while the data is shared across ensemble members.
    parallel_train_step_fn = vmap(partial(train_step_fn, True), in_dims=(0, None, None))
    result_loss, result_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))

    # Reference: step each ensemble member separately, without the transforms.
    loss0, weights0 = unpack(train_step_fn(False, slice_weights(batched_weights, 0), points, labels))
    loss1, weights1 = unpack(train_step_fn(False, slice_weights(batched_weights, 1), points, labels))
    expected_loss = torch.stack([loss0, loss1])
    expected_weights = tuple(torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1))

    self.assertEqual(result_loss, expected_loss)
    self.assertEqual(result_weights, expected_weights)
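# Usage sketch (not part of the test, and assuming the definitions above are in
# scope): once the vmapped step matches the reference, the whole ensemble can be
# trained in lockstep by iterating it and threading the stacked weights through.
# `num_steps` is a hypothetical knob.
num_steps = 100
ensemble_weights = init_fn(num_models=2)
for _ in range(num_steps):
    per_model_loss, ensemble_weights = unpack(
        parallel_train_step_fn(ensemble_weights, points, labels))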
def test_maml_regression(self, device):
    class ThreeLayerNet(nn.Module):
        def __init__(self):
            super(ThreeLayerNet, self).__init__()
            self.fc1 = nn.Linear(1, 40)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(40, 40)
            self.relu2 = nn.ReLU()
            self.fc3 = nn.Linear(40, 1)

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu1(x)
            x = self.fc2(x)
            x = self.relu2(x)
            x = self.fc3(x)
            return x

    # The prototype doesn't like F.mse_loss.
    def mse_loss(x, y):
        return torch.mean((x - y) ** 2)

    params, net, _ = make_functional(ThreeLayerNet().to(device))
    K = 20
    losses = []
    num_tasks = 4
    alpha = 0.1

    def sample_tasks(outer_batch_size, inner_batch_size):
        # Select amplitude and phase for the task
        As = []
        phases = []
        for _ in range(outer_batch_size):
            As.append(np.random.uniform(low=0.1, high=.5))
            phases.append(np.random.uniform(low=0., high=np.pi))

        def get_batch():
            xs, ys = [], []
            for A, phase in zip(As, phases):
                x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
                y = A * np.sin(x + phase)
                xs.append(x)
                ys.append(y)
            return torch.tensor(xs, dtype=torch.float, device=device), \
                torch.tensor(ys, dtype=torch.float, device=device)

        x1, y1 = get_batch()
        x2, y2 = get_batch()
        return x1, y1, x2, y2

    def get_loss_for_task(use_transform, x1, y1, x2, y2):
        def inner_loss(params, x1, y1):
            f = net(params, (x1,))
            loss = mse_loss(f, y1)
            return loss

        # Inner-loop adaptation: one gradient step on the support set (x1, y1).
        if use_transform:
            grads = grad(inner_loss)(params, x1, y1)
        else:
            loss = inner_loss(params, x1, y1)
            grads = torch.autograd.grad(loss, params, create_graph=True)
        new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

        # Evaluate the adapted parameters on the query set (x2, y2).
        v_f = net(new_params, (x2,))
        return mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)

    # Compute with vmap+grad
    inner_losses = vmap(partial(get_loss_for_task, True))(
        task[0], task[1], task[2], task[3])
    loss2 = sum(inner_losses) / len(inner_losses)
    result_grads = torch.autograd.grad(loss2, params)

    # Compute without vmap+grad
    inner_losses = [
        get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
        for i in range(num_tasks)
    ]
    loss2 = sum(inner_losses) / len(inner_losses)
    expected_grads = torch.autograd.grad(loss2, params)

    self.assertEqual(result_grads, expected_grads)
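# Outer-loop sketch (not part of the test, and assuming `params` and
# `result_grads` above are in scope): a manual MAML meta-update would apply the
# meta-gradients to the parameters. `outer_lr` is a hypothetical learning rate.
outer_lr = 0.01
with torch.no_grad():
    params = [p - outer_lr * g for p, g in zip(params, result_grads)]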
def train(args, model, train_loader, optimizer, epoch, device):
    start_time = datetime.now()

    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):
        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # In order to use functional vmap+grad, we need to be able to
        # pass the weights to a model.
        func_model, weights = make_functional(model)

        # To use vmap+grad to compute per-sample-grads, the forward pass
        # must be re-formulated on a single example.
        # We use the `grad` operator to compute forward+backward on a single example,
        # and finally `vmap` to do forward+backward on multiple examples.
        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, images)
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # `grad(f)` is a functional API that returns a function `f'` that
        # computes gradients by running both the forward and the backward pass.
        # We also want to extract some intermediate values from the computation
        # (i.e. the loss and the output).
        #
        # To extract the loss, we use the `grad_and_value` API, which returns the
        # gradient of the loss w.r.t. the weights together with the loss itself.
        #
        # To extract the output, we use the `has_aux=True` flag.
        # `has_aux=True` assumes that `f` returns a tuple of two values,
        # where the first is to be differentiated and the second ("auxiliary value")
        # is not. `f'` then returns the gradients, the loss, and the auxiliary value.
        grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)
        sample_grads, (sample_loss, output) = \
            vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
        loss = sample_loss.mean()

        for grad_sample, weight in zip(sample_grads, model.parameters()):
            weight.grad_sample = grad_sample.detach()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
        clip_and_accumulate_and_add_noise(
            model, args.max_per_sample_grad_norm, args.sigma)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)
        top1_acc.append(acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        optimizer.step()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            print(
                f"\tTrain Epoch: {epoch} \t"
                f"Loss: {np.mean(losses):.6f} "
                f"Acc@1: {np.mean(top1_acc):.6f} "
            )

    train_duration = datetime.now() - start_time
    return train_duration
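# The helper below is NOT the example's actual `clip_and_accumulate_and_add_noise`;
# it is a minimal sketch of what such a DP-SGD step could look like, assuming each
# parameter carries the `.grad_sample` attribute populated in the loop above.
def _clip_and_accumulate_and_add_noise_sketch(model, max_per_sample_grad_norm, sigma):
    params = [p for p in model.parameters() if hasattr(p, "grad_sample")]
    sample_grads = [p.grad_sample for p in params]
    batch_size = sample_grads[0].shape[0]

    # Per-sample L2 norm across all parameters, then the per-sample clip factor.
    flat = torch.cat([g.reshape(batch_size, -1) for g in sample_grads], dim=1)
    per_sample_norms = flat.norm(2, dim=1)
    clip_factor = (max_per_sample_grad_norm / (per_sample_norms + 1e-6)).clamp(max=1.0)

    for p, g in zip(params, sample_grads):
        # Clip each per-sample gradient, sum over the batch, add Gaussian noise,
        # and average so a standard optimizer step can consume `p.grad`.
        scale = clip_factor.view(-1, *([1] * (g.dim() - 1)))
        summed = (g * scale).sum(dim=0)
        noise = torch.randn_like(summed) * (sigma * max_per_sample_grad_norm)
        p.grad = (summed + noise) / batch_size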
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        return (self.mod(x)**2).sum()

batch_size = 16
features = 64
num_layers = 8
inp = torch.randn((batch_size, features))

mod = Foo(num_layers, features)
jit_mod = torch.jit.script(mod)

weights, func_model, _ = make_functional(mod)
lr = 1.0

def functional_step(x, weights):
    grads, value = grad_and_value(func_model)(weights, (x,))
    new_weights = [
        weight - lr * p_grad for p_grad, weight in zip(grads, weights)
    ]
    return value, new_weights

optim = torch.optim.SGD(jit_mod.parameters(), lr=lr, momentum=0, dampening=0,