def test_per_sample_grads_embeddingnet(self, device):
        """Per-sample grads of an embedding model: vmap(grad) vs. a Python loop of grad."""
        class SampleNet(nn.Module):
            def __init__(self, vocab_size: int):
                # nn.Module.__init__ must run before any submodule is assigned.
                super().__init__()
                self.emb = nn.Embedding(vocab_size, 16)
                self.fc1 = nn.Linear(16, 16)
                self.fc2 = nn.Linear(16, 2)

            def forward(self, x):
                x = self.emb(x)
                x = torch.transpose(x, -1, -2)
                x = torch.mean(x, -1)
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                return x

            def name(self):
                return "SampleNet"

        # Create our inputs...
        vocab_size = 1000
        batch_shape = [64]
        batch_size = batch_shape[0]
        words_per_sentence = 5
        data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
        # randint's upper bound is exclusive: draw from both of fc2's 2 classes.
        targets = torch.randint(0, 2, (*batch_shape,), device=device)

        # Construct our module
        net = SampleNet(vocab_size).to(device=device)
        criterion = nn.CrossEntropyLoss()

        weights, net_func, _ = make_functional(net)

        def compute_loss(weights, data, target):
            output = net_func(weights, (data,))
            result = criterion(output, target)
            return result

        # Reference: one grad call per example, then stack per-weight shards.
        expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(batch_size)]
        expected = zip(*expected)
        expected = tuple(torch.stack(shards) for shards in expected)

        result = vmap(partial(grad(compute_loss), weights))(data, targets)
        for r, e in zip(result, expected):
            # TODO: Check if the rtol is a problem
            self.assertEqual(r, e, atol=0, rtol=1e-4)
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        """Return the scalar sum of the squared outputs of ``self.mod``."""
        out = self.mod(x)
        return out.pow(2).sum()

# Benchmark configuration and a random input batch.
batch_size, features, num_layers = 16, 64, 8
inp = torch.randn(batch_size, features)

# Eager module, its TorchScript counterpart, and a functional view of it.
mod = Foo(num_layers, features)
jit_mod = torch.jit.script(mod)
func_model, weights = make_functional(mod)
lr = 1.0

def functional_step(x, weights):
    """Run one plain-SGD step on the module-level ``func_model``.

    Args:
        x: input passed to ``func_model``.
        weights: iterable of weight tensors; fresh leaf copies are made here.

    Returns:
        ``(out, new_weights)``, where ``out`` is ``func_model(weights, x)`` and
        ``new_weights`` holds ``weight - lr * d(out)/d(weight)`` for each weight.
    """
    weights = [weight.detach().requires_grad_() for weight in weights]
    out = func_model(weights, x)
    # Nothing ever called backward, so ``weight.grad`` would be None here;
    # compute the gradients of ``out`` explicitly instead.
    grads = torch.autograd.grad(out, weights)
    new_weights = [weight - lr * g for weight, g in zip(weights, grads)]
    return out, new_weights

# Reference optimizer driving the scripted module's parameters.
optim = torch.optim.SGD(params=jit_mod.parameters(), lr=lr)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

    def forward(self, x):
        """fc1 -> ReLU -> fc2, then log-softmax over the last dimension."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, -1)

loss_fn = nn.NLLLoss(reduction="mean")

# Step 3: Make the model functional(!!) and define a training function.
# NB: core PyTorch has no such mechanism yet (it comes from make_functional), but we want it to:
func_model, weights = make_functional(MLPClassifier().to(device=DEVICE))

def train_step_fn(weights, batch, targets, lr=0.2):
    def compute_loss(weights, batch, targets):
        output = func_model(weights, batch)
        loss = loss_fn(output, targets)
        return loss

    grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)

    # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon)
    # so we are going to re-implement SGD here.
    new_weights = []
    with torch.no_grad():
        for grad_weight, weight in zip(grad_weights, weights):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

# TODO: Use F.mse_loss

def mse_loss(x, y):
    """Return the mean of the squared elementwise difference ``x - y``."""
    residual = x - y
    return residual.pow(2).mean()

# Functional model and the outer-loop optimizer over its parameters.
net, params = make_functional(ThreeLayerNet())
opt = torch.optim.Adam(params, lr=0.001)

# Remaining hyperparameters and the loss history.
alpha = 0.1
K = 20
losses = list()
num_tasks = 4

def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def forward(self, x):
        """Apply fc1, relu1, fc2, relu2, fc3 in that order."""
        for layer in (self.fc1, self.relu1, self.fc2, self.relu2, self.fc3):
            x = layer(x)
        return x

# The prototype doesn't like F.mse_loss.
def mse_loss(x, y):
    """Mean squared error between tensors ``x`` and ``y``."""
    squared_error = torch.pow(x - y, 2)
    return torch.mean(squared_error)

# Functional model and the outer-loop optimizer over its parameters.
params, net, _ = make_functional(ThreeLayerNet())
opt = torch.optim.Adam(params, lr=0.001)

# Remaining hyperparameters and the loss history.
alpha = 0.1
K = 20
losses = list()
num_tasks = 4

def sample_tasks(outer_batch_size, inner_batch_size):
    # Select amplitude and phase for the task
    As = []
    phases = []
    for _ in range(outer_batch_size):
        As.append(np.random.uniform(low=0.1, high=.5))
        phases.append(np.random.uniform(low=0., high=np.pi))
    def test_resnet18_per_sample_grads(self, device):
        """Per-sample grads of a GroupNorm ResNet-18: vmap(grad) vs. per-example autograd.grad."""
        # Straight out of opacus
        def _replace_child(
            root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
        ) -> None:
            # find the immediate parent
            parent = root
            name_list = child_name.split(".")
            for name in name_list[:-1]:
                parent = parent._modules[name]
            # replace the child with the converted module
            parent._modules[name_list[-1]] = converter(parent._modules[name_list[-1]])

        def replace_all_modules(
            root: nn.Module,
            target_class: Type[nn.Module],
            converter: Callable[[nn.Module], nn.Module],
        ) -> nn.Module:
            # base case
            if isinstance(root, target_class):
                return converter(root)

            for name, obj in root.named_modules():
                if isinstance(obj, target_class):
                    _replace_child(root, name, converter)
            return root

        def _batchnorm_to_groupnorm(module: nn.modules.batchnorm._BatchNorm) -> nn.Module:
            return nn.GroupNorm(min(32, module.num_features), module.num_features, affine=True)

        def convert_batchnorm_modules(
            model: nn.Module,
            converter: Callable[
                [nn.modules.batchnorm._BatchNorm], nn.Module
            ] = _batchnorm_to_groupnorm,
        ) -> nn.Module:
            return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, converter)

        import torchvision.models as models
        # BatchNorm mixes examples through batch statistics, so it is swapped for
        # GroupNorm before computing per-example gradients.
        model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
        criterion = nn.CrossEntropyLoss()

        weights, func_model, _ = make_functional(model)

        def compute_loss(weights, image, target):
            # Operates on a single example; add a batch dim of 1 for the model.
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, (images,))
            loss = criterion(output, targets)
            return loss

        batch_size = 3
        images = torch.randn(batch_size, 3, 32, 32, device=device)
        targets = torch.randint(0, 10, (batch_size,), device=device)

        result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets)

        expected_grads = [
            torch.autograd.grad(compute_loss(weights, images[i], targets[i]), weights)
            for i in range(batch_size)
        ]
        expected_grads = tuple(torch.stack(shards) for shards in zip(*expected_grads))

        self.assertEqual(result_grads, expected_grads)
 def init_fn(num_models):
     """Build ``num_models`` fresh MLPClassifiers and stack their weights.

     Returns a tuple with one tensor per weight slot; each tensor stacks that
     slot across all models along a new leading dimension and is detached.
     """
     classifiers = [MLPClassifier().to(device) for _ in range(num_models)]
     per_model_weights = [make_functional(clf)[0] for clf in classifiers]
     return tuple(
         torch.stack(group).detach() for group in zip(*per_model_weights)
     )
    def test_ensemble_regression(self, device):
        """Train two MLPs in parallel with vmap and compare to training each one separately."""
        def make_spirals(n_samples, noise_std=0., rotations=1.):
            # Points on two mirrored spiral arms; label is 1 where the sign is positive.
            ts = torch.linspace(0, 1, n_samples, device=device)
            rs = ts ** 0.5
            thetas = rs * rotations * 2 * math.pi
            signs = torch.randint(0, 2, (n_samples,), device=device) * 2 - 1
            labels = (signs > 0).to(torch.long)

            xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=device) * noise_std
            ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=device) * noise_std
            points = torch.stack([xs, ys], dim=1)
            return points, labels

        # Training data shared by the vmapped and the per-model train steps below.
        points, labels = make_spirals(100, noise_std=0.05)

        class MLPClassifier(nn.Module):
            def __init__(self, hidden_dim=32, n_classes=2):
                # nn.Module.__init__ must run before any submodule is assigned.
                super().__init__()
                self.hidden_dim = hidden_dim
                self.n_classes = n_classes

                self.fc1 = nn.Linear(2, self.hidden_dim)
                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

            def forward(self, x):
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                x = F.log_softmax(x, -1)
                return x

        loss_fn = nn.NLLLoss()

        weights, func_model, _ = make_functional(MLPClassifier().to(device))

        def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
            """One SGD step; returns ``(loss, *new_weights)``.

            ``use_transform`` selects functorch ``grad_and_value`` (vmap-compatible)
            versus ``torch.autograd.grad``.
            """
            def compute_loss(weights, batch, targets):
                output = func_model(weights, (batch,))
                loss = loss_fn(output, targets)
                return loss

            if use_transform:
                grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
            else:
                loss = compute_loss(weights, batch, targets)
                grad_weights = torch.autograd.grad(loss, weights)

            new_weights = []
            with torch.no_grad():
                for grad_weight, weight in zip(grad_weights, weights):
                    new_weights.append(weight - grad_weight * lr)
            # NB: return looks weird because torch.vmap must return Tensors
            return (loss, *new_weights)

        def unpack(train_result):
            return train_result[0], train_result[1:]

        def init_fn(num_models):
            models = tuple(MLPClassifier().to(device) for _ in range(num_models))
            weights = tuple(make_functional(model)[0] for model in models)
            weights = tuple(zip(*weights))
            weights = tuple(torch.stack(shards).detach() for shards in weights)
            return weights

        def slice_weights(batched_weights, index):
            return tuple(weight[index].detach().requires_grad_() for weight in batched_weights)

        batched_weights = init_fn(num_models=2)
        parallel_train_step_fn = vmap(partial(train_step_fn, True), in_dims=(0, None, None))

        result_loss, result_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))

        loss0, weights0 = unpack(train_step_fn(False, slice_weights(batched_weights, 0), points, labels))
        loss1, weights1 = unpack(train_step_fn(False, slice_weights(batched_weights, 1), points, labels))
        expected_loss = torch.stack([loss0, loss1])
        expected_weights = tuple(torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1))

        self.assertEqual(result_loss, expected_loss)
        self.assertEqual(result_weights, expected_weights)
    def test_maml_regression(self, device):
        """MAML on sine-wave regression: vmap+grad inner loop vs. a per-task autograd loop.

        Both paths compute the gradient of the mean outer loss w.r.t. the shared
        ``params``; the two results must agree.
        """
        class ThreeLayerNet(nn.Module):
            def __init__(self):
                super(ThreeLayerNet, self).__init__()
                self.fc1 = nn.Linear(1, 40)
                self.relu1 = nn.ReLU()
                self.fc2 = nn.Linear(40, 40)
                self.relu2 = nn.ReLU()
                self.fc3 = nn.Linear(40, 1)

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu1(x)
                x = self.fc2(x)
                x = self.relu2(x)
                x = self.fc3(x)
                return x

        # The prototype doesn't like F.mse_loss.
        def mse_loss(x, y):
            return torch.mean((x - y) ** 2)

        params, net, _ = make_functional(ThreeLayerNet().to(device))
        K = 20  # samples per task (passed as inner_batch_size)
        num_tasks = 4
        alpha = 0.1  # step size of the inner update below

        def sample_tasks(outer_batch_size, inner_batch_size):
            # Select amplitude and phase for the task
            As = []
            phases = []
            for _ in range(outer_batch_size):
                As.append(np.random.uniform(low=0.1, high=.5))
                phases.append(np.random.uniform(low=0., high=np.pi))

            def get_batch():
                xs, ys = [], []
                for A, phase in zip(As, phases):
                    x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
                    y = A * np.sin(x + phase)
                    # Collect every task's samples: result is (tasks, inner_batch_size, 1).
                    xs.append(x)
                    ys.append(y)
                # Stack into one ndarray first; torch.tensor on a list of ndarrays is slow.
                return torch.tensor(np.array(xs), dtype=torch.float, device=device), \
                    torch.tensor(np.array(ys), dtype=torch.float, device=device)

            x1, y1 = get_batch()
            x2, y2 = get_batch()
            return x1, y1, x2, y2

        def get_loss_for_task(use_transform, x1, y1, x2, y2):
            """Loss on (x2, y2) after one inner gradient step on (x1, y1).

            ``use_transform`` selects functorch ``grad`` (vmap-compatible) versus
            ``torch.autograd.grad`` with ``create_graph=True``.
            """
            def inner_loss(params, x1, y1):
                f = net(params, (x1,))
                loss = mse_loss(f, y1)
                return loss

            if use_transform:
                grads = grad(inner_loss)(params, x1, y1)
            else:
                loss = inner_loss(params, x1, y1)
                # create_graph=True keeps the inner step differentiable for the outer grad.
                grads = torch.autograd.grad(loss, params, create_graph=True)
            new_params = [p - alpha * g for p, g in zip(params, grads)]

            v_f = net(new_params, (x2,))
            return mse_loss(v_f, y2)

        task = sample_tasks(num_tasks, K)

        # Compute with vmap+grad
        inner_losses = vmap(partial(get_loss_for_task, True))(
            task[0], task[1], task[2], task[3])
        loss2 = sum(inner_losses) / len(inner_losses)
        result_grads = torch.autograd.grad(loss2, params)

        # Compute without vmap+grad
        inner_losses = [
            get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
            for i in range(num_tasks)
        ]
        loss2 = sum(inner_losses) / len(inner_losses)
        expected_grads = torch.autograd.grad(loss2, params)

        self.assertEqual(result_grads, expected_grads)
def train(args, model, train_loader, optimizer, epoch, device):
    """Run one training epoch using vmap+grad per-sample gradients.

    Per-sample grads are attached to each parameter as ``grad_sample`` and then
    handed to ``clip_and_accumulate_and_add_noise`` (Step 2) before the
    optimizer step.

    Returns:
        The wall-clock duration of the epoch (``datetime.now()`` difference).
    """
    from datetime import datetime

    start_time = datetime.now()

    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):

        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # In order to use functional vmap+grad, we need to be able to
        # pass the weights to a model.
        func_model, weights = make_functional(model)

        # To use vmap+grad to compute per-sample-grads, the forward pass
        # must be re-formulated on a single example.
        # We use the `grad` operator to compute forward+backward on a single example,
        # and finally `vmap` to do forward+backward on multiple examples.
        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, images)
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # `grad(f)` is a functional API that returns a function `f'` that
        # computes gradients by running both the forward and backward pass.
        # We want to extract some intermediate
        # values from the computation (i.e. the loss and output).
        # To extract the loss, we use the `grad_and_value` API, that returns the
        # gradient of the weights w.r.t. the loss and the loss.
        # To extract the output, we use the `has_aux=True` flag.
        # `has_aux=True` assumes that `f` returns a tuple of two values,
        # where the first is to be differentiated and the second "auxiliary value"
        # is not to be differentiated. `f'` returns the gradient w.r.t. the loss,
        # the loss, and the auxiliary value.
        grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)
        sample_grads, (sample_loss, output) = \
            vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
        loss = sample_loss.mean()

        for grad_sample, weight in zip(sample_grads, model.parameters()):
            weight.grad_sample = grad_sample.detach()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
        # NOTE(review): the original call was truncated after the clip norm; the
        # noise argument is reconstructed as `args.sigma` — confirm against
        # clip_and_accumulate_and_add_noise's signature.
        clip_and_accumulate_and_add_noise(model, args.max_per_sample_grad_norm,
                                          args.sigma)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)
        top1_acc.append(acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        # NOTE(review): reconstructed from the comment above; assumes Step 2
        # leaves the accumulated gradients in `.grad` for the optimizer.
        optimizer.step()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            print(f"\tTrain Epoch: {epoch} \t"
                  f"Loss: {np.mean(losses):.6f} "
                  f"Acc@1: {np.mean(top1_acc):.6f} ")
    train_duration = datetime.now() - start_time
    return train_duration
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        """Sum of squares of the wrapped module's output, as a scalar tensor."""
        return torch.sum(torch.pow(self.mod(x), 2))

# Benchmark configuration and a random input batch.
batch_size, features, num_layers = 16, 64, 8
inp = torch.randn(batch_size, features)

# Eager module, its TorchScript counterpart, and a functional view of it.
mod = Foo(num_layers, features)
jit_mod = torch.jit.script(mod)
weights, func_model, _ = make_functional(mod)
lr = 1.0

def functional_step(x, weights):
    """Run one plain-SGD step on the module-level ``func_model``.

    Args:
        x: single input, wrapped in a 1-tuple for ``func_model``.
        weights: sequence of weight tensors.

    Returns:
        ``(value, new_weights)``: the model output at ``weights`` and the list
        ``weight - lr * grad`` for each weight.
    """
    grads, value = grad_and_value(func_model)(weights, (x, ))
    new_weights = [
        weight - lr * p_grad for p_grad, weight in zip(grads, weights)
    ]
    return value, new_weights

# Reference optimizer driving the scripted module's parameters.
optim = torch.optim.SGD(params=jit_mod.parameters(), lr=lr)