Example #1
    def get_ham(self, ham_wo_k, kpt):
        """ Calculates the Hamiltonian for different k-points.
        Args:
            ham_wo_k: Hamiltonian matrix [N_p*N_o, N_p*N_o, N_images]
            kpt: array of k-point coordinates, e.g. for gamma: torch.tensor([[0, 0, 0]])
        Uses (from self):
            shifts: 2D shifts matrix from the compute_shifts function
            lattice: lattice vectors as a 2D matrix, e.g.: torch.tensor([[1.0, 0.0, 0], [0.5, math.sqrt(3.0)/2.0, 0], [0, 0, 10]])
        Returns:
            Hamiltonian for all k-points (N_k, N_p*N_o, N_p*N_o)
        """
        phase_matrix = vmap(vmap(calc_phase_matrix, (None, 0, None)),
                            (0, None, None))(kpt, self.shifts,
                                             self.lattice)  # (N_k, N_images)

        # expand both matrices for automatic broadcasting  # (N_k, 1, 1, N_images)
        g_mat = torch.unsqueeze(torch.unsqueeze(phase_matrix, dim=1), dim=1)
        ham_wo_k = torch.unsqueeze(ham_wo_k,
                                   dim=0)  # (1, N_p*N_o, N_p*N_o, N_images)
        hamiltonian = ham_wo_k * g_mat  # (N_k, N_p*N_o, N_p*N_o, N_images)
        hamiltonian = torch.sum(hamiltonian,
                                dim=-1)  # (N_k, N_p*N_o, N_p*N_o)
        hamiltonian = torch.where(
            torch.abs(hamiltonian) < 1e-10,
            torch.zeros(hamiltonian.shape, dtype=torch.cdouble), hamiltonian)
        # hamiltonian += vmap(set_diagonal_to_inf, 0)(hamiltonian)  # in calculation function
        return hamiltonian
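The helper calc_phase_matrix is not shown in this example. As a hedged sketch of what it is assumed to compute for one k-point and one image shift (a Bloch phase factor; the exact k-point convention is not visible in this excerpt):

import math
import torch

def calc_phase_matrix(kpt, shift, lattice):
    # cartesian vector R of one periodic image: integer shift combined with the lattice rows
    r = shift @ lattice  # (3,)
    # Bloch phase factor e^{i 2*pi k.R}; drop the 2*pi if kpt already carries it
    return torch.exp(1j * 2 * math.pi * torch.dot(kpt, r))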
Example #2
def map_product(distance):
    """ vmap is used to efficiently calculate the distances of the cartesian product of the particles.
    Args:
        distance: distance_fn on single points; the mapped version accepts ((N1, N2, dim), (N3, dim)) arrays as input

    Returns: map product of distance
    """
    return vmap(vmap(vmap(distance, (0, None), 0), (1, None), 1), (None, 0), 0)
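A minimal usage sketch (hypothetical, not from the source; it reuses map_product from above). With the in/out dims shown, the mapped function takes ((N1, N2, dim), (N3, dim)) inputs and returns an (N3, N1, N2) distance table:

import torch
from functorch import vmap

def euclidean(a, b):  # a, b: (dim,) single points
    return (((a - b) ** 2).sum()).sqrt()

pairwise = map_product(euclidean)
x = torch.randn(4, 5, 3)     # (N1, N2, dim)
y = torch.randn(7, 3)        # (N3, dim)
print(pairwise(x, y).shape)  # torch.Size([7, 4, 5])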
Example #3
def get_fallback_and_vmap_exhaustive(op,
                                     arg_values,
                                     kwarg_values,
                                     opinfo=None,
                                     compute_loop_out=True,
                                     bdims=(0, -1)):
    out_dim = 0
    batch_size = 4
    generator = get_exhaustive_batched_inputs(arg_values,
                                              kwarg_values,
                                              batch_size,
                                              bdims=bdims)
    # instance norm calls batch norm
    batch_norm_fns = ("nn.functional.batch_norm",
                      "nn.functional.instance_norm")
    if opinfo is not None and opinfo.name in batch_norm_fns:
        generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values,
                                                                 kwarg_values,
                                                                 batch_size,
                                                                 bdims=bdims)
    for batched_args, in_dims, kwarg_values in generator:
        if compute_loop_out:
            loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args,
                            **kwarg_values)
        else:
            loop_out = None
        # Used for debugging the resulting operations
        # from functorch import make_fx
        # def f(a):
        #     return op(a)
        # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
        # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
        batched_out = vmap(op, in_dims=in_dims,
                           out_dims=out_dim)(*batched_args, **kwarg_values)
        yield (loop_out, batched_out)

        # Tests case where we dispatch to a batching rule with no bdims
        # This should be handled by autogenerated plumbing. For vmap support
        # added via a manual plumbing you may need to handle this specially.
        def add_bdim_if_tensor(x):
            if isinstance(x, torch.Tensor):
                return x.unsqueeze(1)
            return x

        def f(dummy, *args, **kwargs):
            return op(*args, **kwargs)

        dummy = torch.ones(batch_size, 1)
        expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

        inner_in_dims = (0, ) + pytree.tree_map(lambda x: None, in_dims)
        outer_in_dims = (0, ) + in_dims
        output = vmap(vmap(f, inner_in_dims),
                      outer_in_dims)(dummy, *batched_args, **kwarg_values)
        yield (expected, output)
Example #4
def compute_quantities_for_vmap_test(op,
                                     orig_batched_args,
                                     orig_kwarg_values,
                                     in_dims,
                                     out_dim=0,
                                     batch_size=2,
                                     compute_loop_out=True,
                                     clone_inputs=False):
    def maybe_clone_inputs():
        if clone_inputs:
            batched_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
            kwarg_values = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
            return batched_args, kwarg_values
        return orig_batched_args, orig_kwarg_values

    batched_args, kwarg_values = maybe_clone_inputs()
    if compute_loop_out:
        loop_out = loop(op, in_dims, out_dim, batch_size, *batched_args,
                        **kwarg_values)
    else:
        loop_out = None
    # Used for debugging the resulting operations
    # from functorch import make_fx
    # def f(a):
    #     return op(a)
    # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
    # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
    batched_args, kwarg_values = maybe_clone_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args,
                                                              **kwarg_values)
    yield (loop_out, batched_out)

    # Tests case where we dispatch to a batching rule with no bdims
    # This should be handled by autogenerated plumbing. For vmap support
    # added via a manual plumbing you may need to handle this specially.
    def add_bdim_if_tensor(x):
        if isinstance(x, torch.Tensor):
            return x.unsqueeze(1)
        return x

    def f(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    dummy = torch.ones(batch_size, 1)
    expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

    inner_in_dims = (0, ) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0, ) + in_dims
    batched_args, kwarg_values = maybe_clone_inputs()
    output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args,
                                                         **kwarg_values)
    yield (expected, output)
Example #5
 def forward(self, positions, species, kpts):
     ham_wo_k, overlap_wo_k = self.create_hamiltonian_wo_k(
         positions, species)
     hamiltonian = self.get_ham(ham_wo_k, kpts)
     hamiltonian = hamiltonian + vmap(set_diagonal_to_inf, 0)(hamiltonian)
     overlap_matrix = self.get_ham(overlap_wo_k, kpts)
     overlap_matrix = overlap_matrix + torch.unsqueeze(
         torch.diag(torch.ones(overlap_matrix.shape[1])), 0)
     hamiltonian = torch.where(
         torch.abs(hamiltonian) < 10e-10,
         torch.zeros(hamiltonian.shape, dtype=torch.cdouble),
         hamiltonian)  # useless?
     overlap_matrix = torch.where(
         torch.abs(overlap_matrix) < 10e-10,
         torch.zeros(overlap_matrix.shape, dtype=torch.cdouble),
         overlap_matrix)
     overlap_eig_val = []
     overlap_eig_vec = []
     for i in range(
             overlap_matrix.shape[0]):  # list comprehension or tensor?
         overlap_op = xitorch.LinearOperator.m(overlap_matrix[i])
         overlap_eig_single = xilinalg.symeig(overlap_op)
         overlap_eig_val.append(overlap_eig_single[0])
         overlap_eig_vec.append(overlap_eig_single[1])
     overlap_eig = (torch.stack(overlap_eig_val),
                    torch.stack(overlap_eig_vec))
     v_div_sqrt_eig = (
         overlap_eig[1] /
         torch.sqrt(overlap_eig[0]).unsqueeze(dim=-1)).permute(0, 2, 1)
     overlap_root_inv = vmap(torch.matmul,
                             0)(v_div_sqrt_eig,
                                overlap_eig[1].permute(0, 2, 1).conj())
     new_ham = torch.matmul(
         overlap_root_inv,
         torch.matmul(hamiltonian,
                      overlap_root_inv.permute(0, 2, 1).conj())
     )  #vmap(torch.matmul, 0, 0)(overlap_inverse, hamiltonian)
     new_ham = torch.where(
         torch.abs(new_ham) < 10e-10,
         torch.zeros(new_ham.shape, dtype=torch.cdouble),
         new_ham)  # useless?
     solution = []
     for i in range(new_ham.shape[0]):
         hamiltonian_op = xitorch.LinearOperator.m(new_ham[i])
         solution.append(xilinalg.symeig(hamiltonian_op)[0])
     solution = torch.stack(solution)
     solution = solution - find_fermi(solution, self.highest_occupied)
     solution = solution * 27.211396  # convert from atomic units (Hartree) to eV
     return solution  # solution_corrected
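The eigendecomposition block above relies on the Löwdin identity S^{-1/2} = V diag(1/sqrt(lambda)) V^H for a Hermitian overlap S = V diag(lambda) V^H, so that S^{-1/2} H S^{-1/2,H} becomes a standard eigenproblem. A standalone numeric check of that identity on a toy matrix (not from the source):

import torch

S = torch.tensor([[2.0, 0.5], [0.5, 1.0]], dtype=torch.cdouble)  # toy Hermitian overlap
eigval, eigvec = torch.linalg.eigh(S)
s_root_inv = (eigvec / torch.sqrt(eigval)) @ eigvec.conj().T     # V diag(1/sqrt(l)) V^H
print(s_root_inv @ S @ s_root_inv.conj().T)                      # ~ identity matrix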
Example #6
def make_prediction(model, drs):
    norms = torch.norm(drs, dim=1).reshape(-1, 1)
    energies = model(norms)

    network_derivs = vmap(jacrev(model))(norms).squeeze(-1)
    forces = -network_derivs * drs / norms
    return energies, forces
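A hypothetical usage sketch (model and drs come from the surrounding script and are not shown; the MLP below is a stand-in): the model maps a (1,)-shaped distance to a (1,)-shaped energy, jacrev differentiates that energy with respect to the input distance, and vmap applies it per sample.

import torch
import torch.nn as nn
from functorch import vmap, jacrev

model = nn.Sequential(nn.Linear(1, 16), nn.Tanh(), nn.Linear(16, 1))
drs = torch.randn(8, 3)                        # 8 displacement vectors
energies, forces = make_prediction(model, drs)
print(energies.shape, forces.shape)            # torch.Size([8, 1]) torch.Size([8, 3])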
Example #7
    def test_vjp_vmap(self, device):
        x = torch.randn(3, device=device)
        y, vjp_fn = vjp(vmap(torch.sin), x)
        self.assertEqual(y, x.sin())

        v = torch.randn(3, device=device)
        self.assertEqual(vjp_fn(v)[0], x.cos() * v)
Example #8
    def test_vmap_vjp(self, device):
        x = torch.randn(3, device=device)
        _, vjp_fn = vjp(torch.sin, x)

        def foo(x):
            _, vjp_fn = vjp(torch.sin, x)
            return vjp_fn(x)

        y = vmap(foo)(x)
        self.assertEqual(y, vjp_fn(x))

        # TODO: there's a very interesting error message when the following
        # is on CPU
        xs = torch.randn(5, 3, device=device)
        expected = torch.stack([vjp_fn(x)[0] for x in xs])
        result = vmap(lambda x: vjp_fn(x)[0])(xs)
        self.assertEqual(result, expected)
Example #9
def step6():
    parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None))
    batched_weights = init_fn(num_models=2)
    for i in range(2000):
        loss, batched_weights = parallel_train_step_fn(batched_weights, points,
                                                       labels)
        if i % 200 == 0:
            print(loss)
Example #10
 def test_make_fx_vmap(self, device):
     def f(x):
         return torch.sin(x)
     inp = torch.randn(5, 3)
     f = vmap(f)
     fx_f = make_fx(f)(inp)
     new_inp = torch.randn(5, 3)
     self.assertEqual(fx_f(new_inp), f(new_inp))
Example #11
    def get_params_diag(self, dr, species_a, species_b):
        """
        Args:
            dr: 2D matrix of distances of particles
            species_a: element of first particle
            species_b: element of second particle
            kwargs: dict of params for SK, e.g. {"V_sss": 0.2, ...}
        Returns: param_vec: Slater-Koster parameters for given particles and distances
        """
        # (N_particle, N_particle, N_images, 1 -> N_parameters through broadcasting)
        param_diag_vec = torch.unsqueeze(torch.zeros(dr.shape), dim=-1)

        for i in range(self.species_count):
            mask = torch.where(species_a == i, 1.0, 0.0)
            param_diag_vec = param_diag_vec + (
                self.calc_diag_params(dr, i) * torch.unsqueeze(mask, dim=-1)
            )  # shapes are broadcast
        param_diag_vec = vmap(vmap(vmap(torch.diag, 0), 0), 0)(param_diag_vec)
        return param_diag_vec
Example #12
def functorch_per_sample_grad():
    compute_grad = grad(compute_loss)
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    start = time.time()
    result = compute_per_sample_grad(weights, images, targets)
    torch.cuda.synchronize()
    end = time.time()

    return result, end - start  # end - start in seconds
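compute_loss, weights, images, and targets are defined elsewhere in the original script. A hypothetical minimal stand-in showing the same vmap(grad(...), (None, 0, 0)) pattern end to end:

import torch
import torch.nn.functional as F
from functorch import grad, vmap

weights = torch.randn(784, 10)                 # toy linear classifier
images = torch.randn(64, 784)
targets = torch.randint(0, 10, (64,))

def compute_loss(weights, image, target):
    logits = image @ weights                   # forward pass on a single sample
    return F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(weights, images, targets)
print(per_sample_grads.shape)                  # torch.Size([64, 784, 10])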
Example #13
    def test_per_sample_grads_simple(self, device):
        def compute_loss(weight, x, t):
            y = x @ weight
            return ((y - t) ** 2).sum()

        weight = torch.randn(16, 2, device=device)
        x = torch.randn(64, 16, device=device)
        t = torch.randn(64, 2, device=device)
        result = vmap(partial(grad(compute_loss), weight))(x, t)
        expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)]
        expected = torch.stack(expected)
        # TODO: Check if the rtol is a problem
        self.assertEqual(result, expected, atol=0, rtol=5e-4)
Example #14
    def test_new_empty_materializes_tensor(self, device):
        N = 3
        C = 5

        def foo(y, x):
            result = x.new_empty((C,))
            result.copy_(y)
            return result.sum()

        x = torch.randn(N, device=device)
        y = torch.randn(N, C, device=device)
        result = vmap(grad(foo))(y, x)
        self.assertEqual(result, torch.ones_like(y))
Example #15
    def test_per_sample_grads_embeddingnet(self, device):
        class SampleNet(nn.Module):
            def __init__(self, vocab_size: int):
                super().__init__()
                self.emb = nn.Embedding(vocab_size, 16)
                self.fc1 = nn.Linear(16, 16)
                self.fc2 = nn.Linear(16, 2)

            def forward(self, x):
                x = self.emb(x)
                x = torch.transpose(x, -1, -2)
                x = torch.mean(x, -1)
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                return x

            def name(self):
                return "SampleNet"

        # Create our inputs...
        vocab_size = 1000
        batch_shape = [64]
        words_per_sentence = 5
        data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
        targets = torch.randint(0, 1, (*batch_shape,), device=device)

        # Construct our module
        net = SampleNet(vocab_size).to(device=device)
        criterion = nn.CrossEntropyLoss()

        params = dict(net.named_parameters())
        weights, net_func, _ = make_functional(net)

        def compute_loss(weights, data, target):
            output = net_func(weights, (data,))
            result = criterion(output, target)
            return result

        expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
        expected = zip(*expected)
        expected = tuple(torch.stack(shards) for shards in expected)

        result = vmap(partial(grad(compute_loss), weights))(data, targets)
        for r, e in zip(result, expected):
            # TODO: Check if the rtol is a problem
            self.assertEqual(r, e, atol=0, rtol=1e-4)
Example #16
def train(db, net, device, meta_opt, epoch, log):
    params, buffers, fnet = net
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()

        n_inner_iter = 5
        meta_opt.zero_grad()

        # In parallel, trains one model per task. There is a support (x, y)
        # for each task and a query (x, y) for each task.
        compute_loss_for_task = functools.partial(loss_for_task, net,
                                                  n_inner_iter)
        qry_losses, qry_accs = vmap(compute_loss_for_task)(x_spt, y_spt, x_qry,
                                                           y_qry)

        # Compute the maml loss by summing together the returned losses.
        qry_losses.sum().backward()

        meta_opt.step()
        qry_losses = qry_losses.detach().sum() / task_num
        qry_accs = 100. * qry_accs.sum() / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })
Example #17
 def test_vmap_vmap(self, device):
     x = torch.randn(2, 3, device=device)
     y = vmap(vmap(torch.sin))(x)
     self.assertEqual(y, x.sin())
Example #18
    def get_loss_for_task(x1, y1, x2, y2):
        def inner_loss(params, x1, y1):
            f = net(params, x1)
            loss = mse_loss(f, y1)
            return loss

        grads = grad(inner_loss)(tuple(params), x1, y1)
        new_params = [(params[i] - alpha * grads[i])
                      for i in range(len(params))]

        v_f = net(new_params, x2)
        return mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)
    inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3])
    loss2 = sum(inner_losses) / len(inner_losses)
    loss2.backward()

    opt.step()

    if it % 100 == 0:
        print('Iteration %d -- Outer Loss: %.4f' % (it, loss2))
    losses.append(loss2)

t_A = torch.tensor(0.0).uniform_(0.1, 0.5)
t_b = torch.tensor(0.0).uniform_(0.0, math.pi)

t_x = torch.empty(4, 1).uniform_(-5, 5)
t_y = t_A * torch.sin(t_x + t_b)
Example #19
    def test_maml_regression(self, device):
        class ThreeLayerNet(nn.Module):
            def __init__(self):
                super(ThreeLayerNet, self).__init__()
                self.fc1 = nn.Linear(1, 40)
                self.relu1 = nn.ReLU()
                self.fc2 = nn.Linear(40, 40)
                self.relu2 = nn.ReLU()
                self.fc3 = nn.Linear(40, 1)

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu1(x)
                x = self.fc2(x)
                x = self.relu2(x)
                x = self.fc3(x)
                return x

        # The prototype doesn't like F.mse_loss.
        def mse_loss(x, y):
            return torch.mean((x - y) ** 2)

        params, net, _ = make_functional(ThreeLayerNet().to(device))
        K = 20
        losses = []
        num_tasks = 4
        alpha = 0.1

        def sample_tasks(outer_batch_size, inner_batch_size):
            # Select amplitude and phase for the task
            As = []
            phases = []
            for _ in range(outer_batch_size):
                As.append(np.random.uniform(low=0.1, high=.5))
                phases.append(np.random.uniform(low=0., high=np.pi))
            def get_batch():
                xs, ys = [], []
                for A, phase in zip(As, phases):
                    x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
                    y = A * np.sin(x + phase)
                    xs.append(x)
                    ys.append(y)
                return torch.tensor(xs, dtype=torch.float, device=device), \
                    torch.tensor(ys, dtype=torch.float, device=device)
            x1, y1 = get_batch()
            x2, y2 = get_batch()
            return x1, y1, x2, y2

        def get_loss_for_task(use_transform, x1, y1, x2, y2):
            def inner_loss(params, x1, y1):
                f = net(params, (x1,))
                loss = mse_loss(f, y1)
                return loss

            if use_transform:
                grads = grad(inner_loss)(params, x1, y1)
            else:
                loss = inner_loss(params, x1, y1)
                grads = torch.autograd.grad(loss, params, create_graph=True)
            new_params = [(params[i] - alpha*grads[i]) for i in range(len(params))]

            v_f = net(new_params, (x2,))
            return mse_loss(v_f, y2)

        task = sample_tasks(num_tasks, K)

        # Compute with vmap+grad
        inner_losses = vmap(partial(get_loss_for_task, True))\
                            (task[0], task[1], task[2], task[3])
        loss2 = sum(inner_losses)/len(inner_losses)
        result_grads = torch.autograd.grad(loss2, params)

        # Compute without vmap+grad
        inner_losses = [
            get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
            for i in range(num_tasks)
        ]
        loss2 = sum(inner_losses)/len(inner_losses)
        expected_grads = torch.autograd.grad(loss2, params)

        self.assertEqual(result_grads, expected_grads)
Example #20
    targets = target.unsqueeze(0)
    predictions = fmodel(params, buffers, batch)
    loss = loss_fn(predictions, targets)
    return loss


# Now, let's use ``grad`` to create a new function that computes the gradient
# with respect to the first argument of compute_loss (i.e. the params).
ft_compute_grad = grad(compute_loss)

# ``ft_compute_grad`` computes the gradient for a single (sample, target) pair.
# We can use ``vmap`` to get it to compute the gradient over an entire batch
# of samples and targets. Note that in_dims=(None, None, 0, 0) because we wish
# to map ``ft_compute_grad`` over the 0th dimension of the data and targets
# and use the same params and buffers for each.
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

# Finally, let's use our transformed function to compute per-sample-gradients:
ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads,
                                               ft_per_sample_grads):
    assert torch.allclose(per_sample_grad,
                          ft_per_sample_grad,
                          atol=1e-6,
                          rtol=1e-6)

# A quick note: there are limitations around what types of functions can be
# transformed by vmap. The best functions to transform are pure functions: ones
# whose outputs are determined only by the inputs and that have no side effects
# (e.g. mutation). vmap is unable to handle mutation of arbitrary Python data
# structures, but it is able to handle many in-place PyTorch operations.
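A minimal sketch of that distinction (not from the original tutorial): in-place ops on tensors created inside the mapped function are handled, while side effects on Python state are simply not vectorized.

import torch
from functorch import vmap

def scale_inplace(x):
    y = x.clone()
    y.mul_(2.0)          # in-place op on a tensor created inside the function: fine
    return y

print(vmap(scale_inplace)(torch.ones(3, 2)))   # all twos, shape (3, 2)

calls = {"n": 0}

def impure(x):
    calls["n"] += 1      # Python-side mutation: happens once for the whole batch
    return x * 2

vmap(impure)(torch.ones(3, 2))
print(calls["n"])        # 1, not 3: the body ran a single time on batched tensors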
Example #21
def compute_jac(xp):
    jacobian_rows = [
        torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
        for vec in unit_vectors
    ]
    return torch.stack(jacobian_rows)


jacobian = compute_jac(xp)

# Instead of computing the jacobian row-by-row, we can use ``vmap`` to get rid
# of the for-loop and vectorize the computation. We can't directly apply vmap
# to PyTorch Autograd; instead, functorch provides a ``vjp`` transform:
from functorch import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)

# In another tutorial a composition of reverse-mode AD and vmap gave us
# per-sample-gradients. In this tutorial, composing reverse-mode AD and vmap
# gives us Jacobian computation! Various compositions of vmap and autodiff
# transforms can give us different interesting quantities.
#
# functorch provides ``jacrev`` as a convenience function that performs
# the vmap-vjp composition to compute jacobians. ``jacrev`` accepts an argnums
# argument that says which argument we would like to compute Jacobians with
# respect to.
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)
Example #22
    def test_ensemble_regression(self, device):
        def make_spirals(n_samples, noise_std=0., rotations=1.):
            ts = torch.linspace(0, 1, n_samples)
            rs = ts ** 0.5
            thetas = rs * rotations * 2 * math.pi
            signs = torch.randint(0, 2, (n_samples,)) * 2 - 1
            labels = (signs > 0).to(torch.long)

            xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std
            ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std
            points = torch.stack([xs, ys], dim=1)
            return points.to(device), labels.to(device)

        points, labels = make_spirals(100, noise_std=0.05)

        class MLPClassifier(nn.Module):
            def __init__(self, hidden_dim=32, n_classes=2):
                super().__init__()
                self.hidden_dim = hidden_dim
                self.n_classes = n_classes

                self.fc1 = nn.Linear(2, self.hidden_dim)
                self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

            def forward(self, x):
                x = self.fc1(x)
                x = F.relu(x)
                x = self.fc2(x)
                x = F.log_softmax(x, -1)
                return x

        loss_fn = nn.NLLLoss()

        weights, func_model, _ = make_functional(MLPClassifier().to(device))

        def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
            def compute_loss(weights, batch, targets):
                output = func_model(weights, (batch,))
                loss = loss_fn(output, targets)
                return loss

            if use_transform:
                grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
            else:
                loss = compute_loss(weights, batch, targets)
                grad_weights = torch.autograd.grad(loss, weights)

            new_weights = []
            with torch.no_grad():
                for grad_weight, weight in zip(grad_weights, weights):
                    new_weights.append(weight - grad_weight * lr)
            # NB: return looks weird because torch.vmap must return Tensors
            return (loss, *new_weights)

        def unpack(train_result):
            return train_result[0], train_result[1:]

        def init_fn(num_models):
            models = tuple(MLPClassifier().to(device) for _ in range(num_models))
            weights = tuple(make_functional(model)[0] for model in models)
            weights = tuple(zip(*weights))
            weights = tuple(torch.stack(shards).detach() for shards in weights)
            return weights

        def slice_weights(batched_weights, index):
            return tuple(weight[index].detach().requires_grad_() for weight in batched_weights)

        batched_weights = init_fn(num_models=2)
        parallel_train_step_fn = vmap(partial(train_step_fn, True), in_dims=(0, None, None))

        result_loss, result_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))

        loss0, weights0 = unpack(train_step_fn(False, slice_weights(batched_weights, 0), points, labels))
        loss1, weights1 = unpack(train_step_fn(False, slice_weights(batched_weights, 1), points, labels))
        expected_loss = torch.stack([loss0, loss1])
        expected_weights = tuple(torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1))

        self.assertEqual(result_loss, expected_loss)
        self.assertEqual(result_weights, expected_weights)
Example #23
    def create_hamiltonian_wo_k(self, positions, species):
        """
        Args:
            positions: 2D particle position matrix
            species: array of species
        Uses (from self):
            shifts: 2D shifts matrix from the compute_shifts function
            kwargs: dict of 2D matrices of Slater-Koster parameters
            kwargs_diag: dict of 2D matrices of on-site parameters
            kwargs_overlap: dict of 2D matrices of off-site overlap parameters
        Returns:
            Hamiltonian without k as a matrix
        """
        if self.spd:
            n_orbitals = 9
            get_hop_int = get_hop_int_spd
        else:
            n_orbitals = 4
            get_hop_int = get_hop_int_sp
        # n_orbitals = 9  # 4 for sp, 9 for spd
        create_bondmatrix_mask = bondmatrix_masking(self.cutoff)
        shifted_positions = vmap(shift_fn, (0, None))(positions, self.shifts)
        shifted_pair_distance_vectors = shifted_positions.view(
            positions.shape[0], 1, self.shifts.shape[0], 3
        ) - positions.view(1, positions.shape[0], 1, 3)

        # expand species shape to be the same as the shifted coordinates
        # TODO: device handling?
        shifted_species = torch.repeat_interleave(
            torch.unsqueeze(species, dim=0), self.shifts.shape[0], dim=0).T
        # flatten first dimension for cartesian product
        species_b = torch.repeat_interleave(
            torch.unsqueeze(shifted_species, dim=1), species.shape[0], dim=1)
        species_a = torch.repeat_interleave(
            torch.unsqueeze(shifted_species, dim=1), species.shape[0],
            dim=1).permute(1, 0, 2)
        # separate into two vectors for particle pairs a,b and reshape to (particle number, particle_number, N_images)

        # TODO: device handling?
        dir_cos = get_dir_cos(
            shifted_pair_distance_vectors
        )  # (N_particles, N_particles, N_images, 3)
        pair_distances = torch.linalg.norm(
            shifted_pair_distance_vectors,
            dim=-1)  # (N_particles, N_particles, N_images)
        bondmatrix = create_bondmatrix_mask(pair_distances)
        param_vec_1 = self.get_params(pair_distances, species_a, species_b,
                                      'V')
        param_vec_2 = self.get_params(pair_distances, species_b, species_a,
                                      'V')
        param_vec_1 = param_vec_1 * torch.unsqueeze(bondmatrix, dim=-1)
        param_vec_2 = param_vec_2 * torch.unsqueeze(bondmatrix, dim=-1)
        hamiltonian = get_hop_int(
            torch.cat([param_vec_1, dir_cos, param_vec_2],
                      dim=-1)).permute(2, 3, 4, 0, 1)
        param_diag = self.get_params_diag(pair_distances, species_a, 'diag')
        hamiltonian = hamiltonian + param_diag

        overlap_vec_1 = self.get_params(pair_distances, species_a, species_b,
                                        'S')
        overlap_vec_2 = self.get_params(pair_distances, species_b, species_a,
                                        'S')
        overlap_vec_1 = overlap_vec_1 * torch.unsqueeze(bondmatrix, dim=-1)
        overlap_vec_2 = overlap_vec_2 * torch.unsqueeze(bondmatrix, dim=-1)

        overlap_matrix = get_hop_int(
            torch.cat([overlap_vec_1, dir_cos, overlap_vec_2],
                      dim=-1)).permute(2, 3, 4, 0, 1)
        hamiltonian = torch.permute(hamiltonian, (0, 3, 1, 4, 2)) \
            .reshape(species.shape[0] * n_orbitals, species.shape[0] * n_orbitals, self.shifts.shape[0])
        overlap_matrix = torch.permute(overlap_matrix, (0, 3, 1, 4, 2)) \
            .reshape(species.shape[0] * n_orbitals, species.shape[0] * n_orbitals, self.shifts.shape[0])
        return hamiltonian, overlap_matrix
Example #24
 def test_vmap_grad(self, device):
     x = torch.randn(3, device=device)
     y = vmap(grad(torch.sin))(x)
     self.assertEqual(y, x.cos())
Example #25
 def test_vmap_on_jacrev_simple(self, device):
     x = torch.randn(2, 3, device=device)
     y = vmap(jacrev(torch.sin))(x)
     expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)])
     assert torch.allclose(y, expected)
Example #26
def train(args, model, train_loader, optimizer, epoch, device):
    start_time = datetime.now()

    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):

        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # In order to use functional vmap+grad, we need to be able to
        # pass the weights to a model.
        func_model, weights = make_functional(model)

        # To use vmap+grad to compute per-sample-grads, the forward pass
        # must be re-formulated on a single example.
        # We use the `grad` operator to compute forward+backward on a single example,
        # and finally `vmap` to do forward+backward on multiple examples.
        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, images)
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # `grad(f)` is a functional API that returns a function `f'` that
        # computes gradients by running both the forward and backward pass.
        # We want to extract some intermediate
        # values from the computation (i.e. the loss and output).
        #
        # To extract the loss, we use the `grad_and_value` API, that returns the
        # gradient of the weights w.r.t. the loss and the loss.
        #
        # To extract the output, we use the `has_aux=True` flag.
        # `has_aux=True` assumes that `f` returns a tuple of two values,
        # where the first is to be differentiated and the second "auxiliary value"
        # is not to be differentiated. `f'` returns the gradient w.r.t. the loss,
        # the loss, and the auxiliary value.
        grads_loss_output = grad_and_value(compute_loss_and_output,
                                           has_aux=True)
        sample_grads, (sample_loss, output) = \
            vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
        loss = sample_loss.mean()

        for grad_sample, weight in zip(sample_grads, model.parameters()):
            weight.grad_sample = grad_sample.detach()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
        clip_and_accumulate_and_add_noise(model, args.max_per_sample_grad_norm,
                                          args.sigma)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        top1_acc.append(acc1)

        # make sure we take a step after processing the last mini-batch in the
        # epoch to ensure we start the next epoch with a clean state
        optimizer.step()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            print(f"\tTrain Epoch: {epoch} \t"
                  f"Loss: {np.mean(losses):.6f} "
                  f"Acc@1: {np.mean(top1_acc):.6f} ")
    train_duration = datetime.now() - start_time
    return train_duration
Example #27
    def test_resnet18_per_sample_grads(self, device):
        # Straight out of opacus
        def _replace_child(
            root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
        ) -> None:
            # find the immediate parent
            parent = root
            nameList = child_name.split(".")
            for name in nameList[:-1]:
                parent = parent._modules[name]
            # set to identity
            parent._modules[nameList[-1]] = converter(parent._modules[nameList[-1]])

        def replace_all_modules(
            root: nn.Module,
            target_class: Type[nn.Module],
            converter: Callable[[nn.Module], nn.Module],
        ) -> nn.Module:
            # base case
            if isinstance(root, target_class):
                return converter(root)

            for name, obj in root.named_modules():
                if isinstance(obj, target_class):
                    _replace_child(root, name, converter)
            return root

        def _batchnorm_to_groupnorm(module: nn.modules.batchnorm._BatchNorm) -> nn.Module:
            return nn.GroupNorm(min(32, module.num_features), module.num_features, affine=True)

        def convert_batchnorm_modules(
            model: nn.Module,
            converter: Callable[
                [nn.modules.batchnorm._BatchNorm], nn.Module
            ] = _batchnorm_to_groupnorm,
        ) -> nn.Module:
            return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, converter)

        import torchvision.models as models
        model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
        criterion = nn.CrossEntropyLoss()

        weights, func_model, descriptors = make_functional(model)

        def compute_loss(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, (images,))
            loss = criterion(output, targets)
            return loss

        batch_size = 3
        images = torch.randn(batch_size, 3, 32, 32, device=device)
        targets = torch.randint(0, 10, (batch_size,), device=device)

        result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets)

        expected_grads = [
            torch.autograd.grad(compute_loss(weights, images[i], targets[i]), weights)
            for i in range(batch_size)
        ]
        expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]

        self.assertEqual(result_grads, expected_grads)
Example #28
# combine_state_for_ensemble returns a stateless version of the model (fmodel) and stacked parameters and buffers.
from functorch import combine_state_for_ensemble

fmodel, params, buffers = combine_state_for_ensemble(models)
[p.requires_grad_() for p in params]

# Option 1: get predictions using a different minibatch for each model.
# By default, vmap maps a function across the first dimension of all inputs to the
# passed-in function. After `combine_state_for_ensemble`, each of ``params`` and
# ``buffers`` has an additional dimension of size ``num_models`` at the front,
# and ``minibatches`` has a dimension of size ``num_models``.
print([p.size(0) for p in params])
assert minibatches.shape == (num_models, 64, 1, 28, 28)
from functorch import vmap

predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
assert torch.allclose(predictions1_vmap,
                      torch.stack(predictions1),
                      atol=1e-6,
                      rtol=1e-6)

# Option 2: get predictions using the same minibatch of data.
# vmap has an in_dims argument that specifies which dimensions to map over.
# By using ``None``, we tell vmap we want the same minibatch to apply to all of
# the 10 models.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers,
                                                       minibatch)
assert torch.allclose(predictions2_vmap,
                      torch.stack(predictions2),
                      atol=1e-6,
                      rtol=1e-6)
Example #29
 def foo(x):
     y = vmap(torch.sin)(x)
     return y.sum()
Example #30
    def test_maml_omniglot(self, device):
        # TODO: there appears to be precision issues for float32
        dtype = torch.double

        # TODO: The prototype doesn't support in-place relu (and some other
        # in-place operations. That can be fixed.)
        inplace_relu = False
        n_way = 5
        n_inner_iter = 2
        num_tasks = 2
        class Flatten(nn.Module):
            def forward(self, input):
                return input.view(input.size(0), -1)

        net = nn.Sequential(
            nn.Conv2d(1, 64, 3),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(inplace=inplace_relu),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(inplace=inplace_relu),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 64, 3),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(inplace=inplace_relu),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(64, n_way)).to(device).to(dtype)

        params, buffers, fnet, _, _ = make_functional_with_buffers(net)
        net = (params, buffers, fnet)

        def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
            params, buffers, fnet = net
            querysz = x_qry.size(0)

            def compute_loss(new_params, buffers, x, y):
                logits = fnet(new_params, buffers, (x,))
                loss = F.cross_entropy(logits, y)
                return loss

            new_params = params
            for _ in range(n_inner_iter):
                if use_transform:
                    grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
                else:
                    res = compute_loss(new_params, buffers, x_spt, y_spt)
                    grads = torch.autograd.grad(res, new_params, create_graph=True)
                new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]

            qry_logits = fnet(new_params, buffers, (x_qry,))
            qry_loss = F.cross_entropy(qry_logits, y_qry)
            qry_acc = (qry_logits.argmax(dim=1) == y_qry).sum() / querysz

            return qry_loss, qry_acc

        # Get some sample inputs...
        x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device)
        y_spt = torch.randint(0, 5, (num_tasks, 25), device=device)
        x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device)
        y_qry = torch.randint(0, 5, (num_tasks, 75), device=device)

        # compute with vmap + grad
        compute_loss = partial(loss_for_task, net, n_inner_iter, True)
        qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry)
        result_grads = torch.autograd.grad(qry_losses.sum(), params)

        # compute without vmap + grad
        compute_loss = partial(loss_for_task, net, n_inner_iter, False)
        losses = [compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0]
                  for i in range(num_tasks)]
        expected_grads = torch.autograd.grad(sum(losses), params)

        self.assertEqual(result_grads, expected_grads)