def get_ham(self, ham_wo_k, kpt):
    """Calculate the Hamiltonian (or overlap) matrix for different k-points.

    Every periodic-image slice of ``ham_wo_k`` is multiplied by its phase factor
    (from ``calc_phase_matrix``) and the images are summed.

    Args:
        ham_wo_k: k-independent matrix of shape (N_p*N_o, N_p*N_o, N_images).
        kpt: array of k-point coordinates, one row per k-point,
            e.g. for gamma: ``torch.tensor([[0, 0, 0]])``.

    Also reads:
        self.shifts: 2D shifts matrix produced by the compute_shifts function.
        self.lattice: lattice vectors as a 2D matrix, e.g.
            ``[[1.0, 0.0, 0], [0.5, sqrt(3)/2, 0], [0, 0, 10]]``.

    Returns:
        Complex-double tensor of shape (N_k, N_p*N_o, N_p*N_o). Entries whose
        magnitude is below 1e-10 are set exactly to zero.
    """
    # (N_k, N_images): one phase factor per (k-point, image) pair.
    phase_matrix = vmap(vmap(calc_phase_matrix, (None, 0, None)), (0, None, None))(
        kpt, self.shifts, self.lattice)
    # (N_k, 1, 1, N_images) so it broadcasts against the two matrix dims.
    g_mat = torch.unsqueeze(torch.unsqueeze(phase_matrix, dim=1), dim=1)
    # (1, N_p*N_o, N_p*N_o, N_images)
    ham_wo_k = torch.unsqueeze(ham_wo_k, dim=0)
    # Weight each image by its phase, then sum over images -> (N_k, N_p*N_o, N_p*N_o).
    hamiltonian = torch.sum(ham_wo_k * g_mat, dim=-1)
    # Flush numerical noise to exact zero. zeros_like keeps the zeros on
    # hamiltonian's device; a bare torch.zeros would land on the default
    # device and make torch.where fail for CUDA tensors.
    hamiltonian = torch.where(
        torch.abs(hamiltonian) < 1e-10,
        torch.zeros_like(hamiltonian, dtype=torch.cdouble),
        hamiltonian)
    # The set_diagonal_to_inf correction is applied by the caller, not here.
    return hamiltonian
def map_product(distance):
    """Lift a point-to-point ``distance`` into a cartesian-product distance map.

    Three nested ``vmap`` layers are stacked:
      * the innermost maps over axis 0 of the first argument,
      * the middle one maps over axis 1 of the first argument (output axis 1),
      * the outermost maps over axis 0 of the second argument.

    Args:
        distance: function of two single points (presumably ``(dim,), (dim,)``).

    Returns:
        A function accepting ``(N1, N2, dim)`` and ``(N3, dim)`` arrays whose
        result satisfies ``out[k, i, j] == distance(a[i, j], b[k])``.
    """
    over_first_axis = vmap(distance, (0, None), 0)
    over_second_axis = vmap(over_first_axis, (1, None), 1)
    return vmap(over_second_axis, (None, 0), 0)
def get_fallback_and_vmap_exhaustive(op, arg_values, kwarg_values, opinfo=None, compute_loop_out=True, bdims=(0, -1)):
    """Yield (reference, vmap result) pairs for every exhaustive batching of ``op``'s inputs.

    For each ``(batched_args, in_dims, kwargs)`` produced by the generator, two
    pairs are yielded:

    1. ``(loop_out, batched_out)``: the per-example ``loop`` reference (or
       ``None`` when ``compute_loop_out`` is false) against ``vmap(op)``.
    2. ``(expected, output)``: ``batched_out`` with a new dim inserted at
       position 1, against a doubly-nested vmap whose inner level sees no
       batched tensor arguments.
    """
    out_dim = 0
    batch_size = 4
    # instance norm calls batch norm, so both use the batch-norm-aware generator
    batch_norm_fns = ("nn.functional.batch_norm", "nn.functional.instance_norm")

    generator = get_exhaustive_batched_inputs(arg_values, kwarg_values, batch_size, bdims=bdims)
    if opinfo is not None and opinfo.name in batch_norm_fns:
        generator = get_exhaustive_batched_inputs_for_batch_norm(arg_values, kwarg_values, batch_size, bdims=bdims)

    def add_bdim_if_tensor(x):
        return x.unsqueeze(1) if isinstance(x, torch.Tensor) else x

    def f(dummy, *args, **kwargs):
        # ``dummy`` only carries the extra batch level; it is ignored.
        return op(*args, **kwargs)

    for batched_args, in_dims, batched_kwargs in generator:
        loop_out = (loop(op, in_dims, out_dim, batch_size, *batched_args, **batched_kwargs)
                    if compute_loop_out else None)
        # To debug the resulting operations, trace
        # vmap(op, in_dims=in_dims, out_dims=out_dim) with functorch.make_fx.
        batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **batched_kwargs)
        yield (loop_out, batched_out)

        # Exercise dispatch to a batching rule with no bdims. Autogenerated
        # plumbing should handle this; manually written vmap plumbing may
        # need to special-case it.
        dummy = torch.ones(batch_size, 1)
        expected = pytree.tree_map(add_bdim_if_tensor, batched_out)
        inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
        outer_in_dims = (0,) + in_dims
        output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args, **batched_kwargs)
        yield (expected, output)
def compute_quantities_for_vmap_test(op, orig_batched_args, orig_kwarg_values, in_dims, out_dim=0, batch_size=2, compute_loop_out=True, clone_inputs=False):
    """Yield (reference, vmap result) pairs for one batching of ``op``'s inputs.

    The first pair is ``(loop_out, batched_out)``: a per-example ``loop``
    reference (``None`` if ``compute_loop_out`` is false) against ``vmap(op)``.
    The second pair checks dispatch to a batching rule with no bdims by
    nesting ``vmap`` around a function whose inner level sees no batched tensors.

    When ``clone_inputs`` is true, every computation receives freshly cloned
    inputs, so an in-place ``op`` cannot corrupt later comparisons.
    """
    def maybe_clone_inputs():
        if not clone_inputs:
            return orig_batched_args, orig_kwarg_values
        return (pytree.tree_map(clone_if_tensor, orig_batched_args),
                pytree.tree_map(clone_if_tensor, orig_kwarg_values))

    def add_bdim_if_tensor(x):
        return x.unsqueeze(1) if isinstance(x, torch.Tensor) else x

    def f(dummy, *args, **kwargs):
        # ``dummy`` only carries the extra batch level; it is ignored.
        return op(*args, **kwargs)

    batched_args, kwarg_values = maybe_clone_inputs()
    loop_out = (loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values)
                if compute_loop_out else None)

    # To debug the resulting operations, trace
    # vmap(op, in_dims=in_dims, out_dims=out_dim) with functorch.make_fx.
    batched_args, kwarg_values = maybe_clone_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
    yield (loop_out, batched_out)

    # Exercise dispatch to a batching rule with no bdims. Autogenerated plumbing
    # should handle this; manually written vmap plumbing may need to special-case it.
    dummy = torch.ones(batch_size, 1)
    expected = pytree.tree_map(add_bdim_if_tensor, batched_out)
    inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0,) + in_dims
    batched_args, kwarg_values = maybe_clone_inputs()
    output = vmap(vmap(f, inner_in_dims), outer_in_dims)(dummy, *batched_args, **kwarg_values)
    yield (expected, output)
def forward(self, positions, species, kpts):
    """Compute Fermi-level-shifted band energies (eV) at every k-point.

    Steps:
      1. Build the k-dependent Hamiltonian H(k) and overlap S(k).
      2. Turn the generalized eigenproblem H c = e S c into a standard one
         with a matrix X that satisfies X^H X = S^-1.
      3. Diagonalize X H X^H per k-point with xitorch.

    Args:
        positions: particle positions, forwarded to ``create_hamiltonian_wo_k``.
        species: particle species, forwarded to ``create_hamiltonian_wo_k``.
        kpts: k-point coordinates, forwarded to ``get_ham``.

    Returns:
        Stacked eigenvalues (one row per k-point) minus
        ``find_fermi(eigenvalues, self.highest_occupied)``, converted from
        atomic units to eV.
    """
    au_to_ev = 27.211396  # conversion au (atomic unit) to eV
    # NOTE(review): 10e-10 == 1e-9, ten times the 1e-10 used in get_ham; kept as-is.
    small = 10e-10

    def zero_small(mat):
        # Flush entries whose magnitude is below ``small`` to exact zero. The
        # zeros are created on ``mat``'s device so CUDA inputs work.
        return torch.where(torch.abs(mat) < small,
                           torch.zeros_like(mat, dtype=torch.cdouble), mat)

    ham_wo_k, overlap_wo_k = self.create_hamiltonian_wo_k(positions, species)

    hamiltonian = self.get_ham(ham_wo_k, kpts)
    hamiltonian = hamiltonian + vmap(set_diagonal_to_inf, 0)(hamiltonian)

    overlap_matrix = self.get_ham(overlap_wo_k, kpts)
    # Add the identity to every k-point's overlap matrix, on the same device.
    identity = torch.eye(overlap_matrix.shape[1], device=overlap_matrix.device)
    overlap_matrix = overlap_matrix + torch.unsqueeze(identity, 0)

    hamiltonian = zero_small(hamiltonian)
    overlap_matrix = zero_small(overlap_matrix)

    # Eigendecomposition of S(k), one k-point at a time.
    eig_vals, eig_vecs = [], []
    for overlap_k in overlap_matrix:
        overlap_eig_single = xilinalg.symeig(xitorch.LinearOperator.m(overlap_k))
        eig_vals.append(overlap_eig_single[0])
        eig_vecs.append(overlap_eig_single[1])
    overlap_eig_val = torch.stack(eig_vals)
    overlap_eig_vec = torch.stack(eig_vecs)

    # Dividing by sqrt(lambda).unsqueeze(-1) scales *rows* of V, so after the
    # transpose X = V^T diag(lambda^-1/2) V^H. If symeig returns eigenvectors
    # as columns (V unitary; confirm against the xitorch docs), then
    # X^H X = V diag(1/lambda) V^H = S^-1, so X H X^H has the same eigenvalues
    # as the generalized problem (H, S), even though X is not S^-1/2 itself.
    v_div_sqrt_eig = (overlap_eig_vec / torch.sqrt(overlap_eig_val).unsqueeze(dim=-1)).permute(0, 2, 1)
    overlap_root_inv = torch.matmul(v_div_sqrt_eig, overlap_eig_vec.permute(0, 2, 1).conj())

    new_ham = torch.matmul(
        overlap_root_inv,
        torch.matmul(hamiltonian, overlap_root_inv.permute(0, 2, 1).conj()))
    new_ham = zero_small(new_ham)

    solution = torch.stack([
        xilinalg.symeig(xitorch.LinearOperator.m(new_ham_k))[0] for new_ham_k in new_ham
    ])
    solution = solution - find_fermi(solution, self.highest_occupied)
    return solution * au_to_ev
def make_prediction(model, drs):
    """Predict energies and forces from displacement vectors.

    Args:
        model: callable mapping distances of shape (N, 1) to energies. It is
            also differentiated per sample with ``jacrev``.
        drs: displacement vectors, presumably of shape (N, 3).

    Returns:
        ``(energies, forces)`` with ``forces = -dE/dr * drs / |drs|``.
    """
    distances = drs.norm(dim=1).reshape(-1, 1)
    energies = model(distances)
    per_sample_jacobian = vmap(jacrev(model))
    d_energy = per_sample_jacobian(distances).squeeze(-1)
    forces = -d_energy * drs / distances
    return energies, forces
def test_vjp_vmap(self, device): x = torch.randn(3, device=device) y, vjp_fn = vjp(vmap(torch.sin), x) self.assertEqual(y, x.sin()) v = torch.randn(3, device=device) self.assertEqual(vjp_fn(v)[0], x.cos() * v)
def test_vmap_vjp(self, device): x = torch.randn(3, device=device) _, vjp_fn = vjp(torch.sin, x) def foo(x): _, vjp_fn = vjp(torch.sin, x) return vjp_fn(x) y = vmap(foo)(x) self.assertEqual(y, vjp_fn(x)) # TODO: there's a very interesting error message when the following # is on CPU xs = torch.randn(5, 3, device=device) expected = torch.stack([vjp_fn(x)[0] for x in xs]) result = vmap(lambda x: vjp_fn(x)[0])(xs) self.assertEqual(result, expected)
def step6():
    """Train two models in parallel with vmap and print the batched loss every 200 steps."""
    # Weights are batched (dim 0); points and labels are shared by both models.
    train_models = vmap(train_step_fn, in_dims=(0, None, None))
    weights = init_fn(num_models=2)
    for step in range(2000):
        loss, weights = train_models(weights, points, labels)
        if step % 200 == 0:
            print(loss)
def test_make_fx_vmap(self, device): def f(x): return torch.sin(x) inp = torch.randn(5, 3) f = vmap(f) fx_f = make_fx(f)(inp) new_inp = torch.randn(5, 3) self.assertEqual(fx_f(new_inp), f(new_inp))
def get_params_diag(self, dr, species_a, species_b):
    """Build diagonal (on-site) parameter matrices for every particle pair.

    Args:
        dr: tensor of pair distances, (N_particle, N_particle, N_images).
        species_a: species index of the first particle of each pair
            (broadcastable to ``dr``).
        species_b: unused; kept so existing call sites keep working.

    Returns:
        Tensor of shape ``dr.shape + (P, P)``. The last two dims are diagonal
        matrices holding ``self.calc_diag_params(dr, i)`` for the species ``i``
        given by ``species_a``.
    """
    # Trailing singleton broadcasts up to the P parameters of calc_diag_params.
    # Allocate on dr's device so CUDA inputs work.
    param_diag_vec = torch.unsqueeze(torch.zeros(dr.shape, device=dr.device), dim=-1)
    for i in range(self.species_count):
        mask = torch.where(species_a == i, 1.0, 0.0)
        # shapes are broadcast
        param_diag_vec = param_diag_vec + (
            self.calc_diag_params(dr, i) * torch.unsqueeze(mask, dim=-1)
        )
    # Turn every length-P vector into a PxP diagonal matrix over all leading dims.
    return torch.diag_embed(param_diag_vec)
def functorch_per_sample_grad():
    """Time per-sample gradient computation with ``vmap(grad(compute_loss))``.

    Uses the module-level ``compute_loss``, ``weights``, ``images`` and ``targets``.

    Returns:
        ``(result, elapsed)``: the per-sample gradients and the elapsed time in seconds.
    """
    compute_grad = grad(compute_loss)
    # (None, 0, 0): share the weights, map over dim 0 of images and targets.
    compute_per_sample_grad = vmap(compute_grad, (None, 0, 0))

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start = time.perf_counter()
    result = compute_per_sample_grad(weights, images, targets)
    # CUDA work is asynchronous, so wait for it before stopping the clock.
    # torch.cuda.synchronize() raises when CUDA is unavailable, so guard it.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    end = time.perf_counter()
    return result, end - start
def test_per_sample_grads_simple(self, device): def compute_loss(weight, x, t): y = x @ weight return ((y - t) ** 2).sum() weight = torch.randn(16, 2, device=device) x = torch.randn(64, 16, device=device) t = torch.randn(64, 2, device=device) result = vmap(partial(grad(compute_loss), weight))(x, t) expected = [grad(compute_loss)(weight, x[i], t[i]) for i in range(64)] expected = torch.stack(expected) # TODO: Check if the rtol is a problem self.assertEqual(result, expected, atol=0, rtol=5e-4)
def test_new_empty_materializes_tensor(self, device): N = 3 C = 5 def foo(y, x): result = x.new_empty((C,)) result.copy_(y) return result.sum() x = torch.randn(N, device=device) y = torch.randn(N, C, device=device) result = vmap(grad(foo))(y, x) self.assertEqual(result, torch.ones_like(y))
def test_per_sample_grads_embeddingnet(self, device):
    """Per-sample grads of an embedding net from vmap(grad) must match a per-sample loop."""
    class SampleNet(nn.Module):
        def __init__(self, vocab_size: int):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, 16)
            self.fc1 = nn.Linear(16, 16)
            self.fc2 = nn.Linear(16, 2)

        def forward(self, x):
            x = self.emb(x)
            x = torch.transpose(x, -1, -2)
            x = torch.mean(x, -1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            return x

        def name(self):
            return "SampleNet"

    # Create our inputs...
    vocab_size = 1000
    batch_shape = [64]
    words_per_sentence = 5
    n_classes = 2  # must match the output size of SampleNet.fc2
    data = torch.randint(0, vocab_size, (*batch_shape, words_per_sentence), device=device)
    # randint's upper bound is exclusive: the old (0, 1) range produced only
    # label 0. Draw from every class instead.
    targets = torch.randint(0, n_classes, (*batch_shape,), device=device)

    # Construct our module
    net = SampleNet(vocab_size).to(device=device)
    criterion = nn.CrossEntropyLoss()

    weights, net_func, _ = make_functional(net)

    def compute_loss(weights, data, target):
        output = net_func(weights, (data,))
        result = criterion(output, target)
        return result

    expected = [grad(compute_loss)(weights, data[i], targets[i]) for i in range(64)]
    expected = zip(*expected)
    expected = tuple(torch.stack(shards) for shards in expected)

    result = vmap(partial(grad(compute_loss), weights))(data, targets)
    for r, e in zip(result, expected):
        # TODO: Check if the rtol is a problem
        self.assertEqual(r, e, atol=0, rtol=1e-4)
def train(db, net, device, meta_opt, epoch, log):
    """Run one epoch of MAML meta-training, one task batch per iteration.

    Each iteration:
      * samples support/query sets for a batch of tasks from ``db``;
      * evaluates ``loss_for_task`` on every task in parallel with ``vmap``;
      * back-propagates the summed query loss and steps ``meta_opt``;
      * appends a record to ``log``.
    """
    # Not used below; the unpack fails fast if ``net`` is not a 3-item iterable.
    _params, _buffers, _fnet = net
    n_train_iter = db.x_train.shape[0] // db.batchsz
    n_inner_iter = 5

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()
        task_num, _setsz, _c, _h, _w = x_spt.size()

        meta_opt.zero_grad()

        # One model per task, trained in parallel; each task has its own
        # support (x, y) and query (x, y).
        per_task_loss = functools.partial(loss_for_task, net, n_inner_iter)
        qry_losses, qry_accs = vmap(per_task_loss)(x_spt, y_spt, x_qry, y_qry)

        # The MAML loss is the sum of the per-task query losses.
        qry_losses.sum().backward()
        meta_opt.step()

        qry_losses = qry_losses.detach().sum() / task_num
        qry_accs = 100. * qry_accs.sum() / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        record = {
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        }
        log.append(record)
def test_vmap_vmap(self, device): x = torch.randn(2, 3, device=device) y = vmap(vmap(torch.sin))(x) self.assertEqual(y, x.sin())
def get_loss_for_task(x1, y1, x2, y2): def inner_loss(params, x1, y1): f = net(params, x1) loss = mse_loss(f, y1) return loss grads = grad(inner_loss)(tuple(params), x1, y1) new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))] v_f = net(new_params, x2) return mse_loss(v_f, y2) task = sample_tasks(num_tasks, K) inner_losses = vmap(get_loss_for_task)(task[0], task[1], task[2], task[3]) loss2 = sum(inner_losses) / len(inner_losses) loss2.backward() opt.step() if it % 100 == 0: print('Iteration %d -- Outer Loss: %.4f' % (it, loss2)) losses.append(loss2) t_A = torch.tensor(0.0).uniform_(0.1, 0.5) t_b = torch.tensor(0.0).uniform_(0.0, math.pi) t_x = torch.empty(4, 1).uniform_(-5, 5) t_y = t_A * torch.sin(t_x + t_b)
def test_maml_regression(self, device):
    """MAML on sine-wave regression: meta-gradients from vmap(grad) must
    match those from a per-task loop using torch.autograd.grad."""
    class ThreeLayerNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(1, 40)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(40, 40)
            self.relu2 = nn.ReLU()
            self.fc3 = nn.Linear(40, 1)

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu1(x)
            x = self.fc2(x)
            x = self.relu2(x)
            x = self.fc3(x)
            return x

    # The prototype doesn't like F.mse_loss.
    def mse_loss(x, y):
        return torch.mean((x - y) ** 2)

    params, net, _ = make_functional(ThreeLayerNet().to(device))
    K = 20
    num_tasks = 4
    alpha = 0.1

    def sample_tasks(outer_batch_size, inner_batch_size):
        """Sample support and query (x, y) batches for ``outer_batch_size`` sine tasks."""
        # Select amplitude and phase for the task
        As = []
        phases = []
        for _ in range(outer_batch_size):
            As.append(np.random.uniform(low=0.1, high=.5))
            phases.append(np.random.uniform(low=0., high=np.pi))

        def get_batch():
            xs, ys = [], []
            for A, phase in zip(As, phases):
                x = np.random.uniform(low=-5., high=5., size=(inner_batch_size, 1))
                y = A * np.sin(x + phase)
                xs.append(x)
                ys.append(y)
            # Stack into one ndarray first: torch.tensor() on a list of
            # ndarrays is very slow and emits a UserWarning.
            return (torch.tensor(np.array(xs), dtype=torch.float, device=device),
                    torch.tensor(np.array(ys), dtype=torch.float, device=device))

        x1, y1 = get_batch()
        x2, y2 = get_batch()
        return x1, y1, x2, y2

    def get_loss_for_task(use_transform, x1, y1, x2, y2):
        """Query loss on (x2, y2) after one inner SGD step on the support set (x1, y1)."""
        def inner_loss(params, x1, y1):
            f = net(params, (x1,))
            loss = mse_loss(f, y1)
            return loss

        if use_transform:
            grads = grad(inner_loss)(params, x1, y1)
        else:
            loss = inner_loss(params, x1, y1)
            # create_graph=True so the outer (meta) gradient flows through this step.
            grads = torch.autograd.grad(loss, params, create_graph=True)
        new_params = [(params[i] - alpha * grads[i]) for i in range(len(params))]

        v_f = net(new_params, (x2,))
        return mse_loss(v_f, y2)

    task = sample_tasks(num_tasks, K)

    # Compute with vmap+grad
    inner_losses = vmap(partial(get_loss_for_task, True))(task[0], task[1], task[2], task[3])
    loss2 = sum(inner_losses) / len(inner_losses)
    result_grads = torch.autograd.grad(loss2, params)

    # Compute without vmap+grad
    inner_losses = [
        get_loss_for_task(False, task[0][i], task[1][i], task[2][i], task[3][i])
        for i in range(num_tasks)
    ]
    loss2 = sum(inner_losses) / len(inner_losses)
    expected_grads = torch.autograd.grad(loss2, params)

    self.assertEqual(result_grads, expected_grads)
targets = target.unsqueeze(0) predictions = fmodel(params, buffers, batch) loss = loss_fn(predictions, targets) return loss # Now, let's use ``grad`` to create a new function that computes the gradient # with respect to the first argument of compute_loss (i.e. the params). ft_compute_grad = grad(compute_loss) # ``ft_compute_grad`` computes the gradient for a single (sample, target) pair. # We can use ``vmap`` to get it to compute the gradient over an entire batch # of samples and targets. Note that in_dims=(None, None, 0, 0) because we wish # to map ``ft_compute_grad`` over the 0th dimension of the data and targets # and use the same params and buffers for each. ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0)) # Finally, let's used our transformed function to compute per-sample-gradients: ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets) for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads): assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-6, rtol=1e-6) # A quick note: there are limitations around what types of functions can be # transformed by vmap. The best functions to transform are ones that are # pure functions: a function where the outputs are only determined by the inputs # that have no side effects (e.g. mutation). vmap is unable to handle mutation of # arbitrary Python data structures, but it is able to handle many in-place
def compute_jac(xp):
    """Jacobian of ``predict(weight, bias, xp)`` w.r.t. ``xp``, one row at a time.

    Each ``torch.autograd.grad`` call with a cotangent ``vec`` yields one
    vector-Jacobian product. With ``unit_vectors`` presumably the one-hot basis
    vectors defined earlier in the tutorial, each product is one Jacobian row.
    """
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

jacobian = compute_jac(xp)

# Instead of computing the Jacobian row by row, we can use ``vmap`` to remove
# the for-loop and vectorize the computation. vmap cannot be applied directly
# to torch.autograd.grad; instead, functorch provides a ``vjp`` transform.
# Mapping the returned ``vjp_fn`` over the unit vectors computes all rows at once:
from functorch import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
# vjp_fn returns a 1-tuple (one entry per primal), hence the trailing comma.
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
assert torch.allclose(ft_jacobian, jacobian)

# In another tutorial, composing reverse-mode AD with vmap gave us
# per-sample gradients. Here the same composition gives us Jacobian
# computation. Different compositions of vmap and autodiff transforms
# produce different useful quantities.
#
# functorch provides ``jacrev`` as a convenience function that performs
# the vmap-vjp composition to compute Jacobians. ``jacrev`` accepts an
# ``argnums`` argument that selects which argument to differentiate with
# respect to (2 here, i.e. ``x``).
from functorch import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(ft_jacobian, jacobian)
def test_ensemble_regression(self, device):
    """One vmap-ed training step over an ensemble of two MLPs must match
    training each model separately with torch.autograd.grad."""
    def make_spirals(n_samples, noise_std=0., rotations=1.):
        """Two interleaved noisy spirals.

        Returns ``(points, labels)``: points of shape (n_samples, 2) and
        long labels of shape (n_samples,), both moved to ``device``.
        """
        ts = torch.linspace(0, 1, n_samples)
        rs = ts ** 0.5
        thetas = rs * rotations * 2 * math.pi
        # Random +1/-1 per sample selects which of the two spirals it belongs to.
        signs = torch.randint(0, 2, (n_samples,)) * 2 - 1
        labels = (signs > 0).to(torch.long)

        xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples) * noise_std
        ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples) * noise_std
        points = torch.stack([xs, ys], dim=1)
        return points.to(device), labels.to(device)

    points, labels = make_spirals(100, noise_std=0.05)

    class MLPClassifier(nn.Module):
        def __init__(self, hidden_dim=32, n_classes=2):
            super().__init__()
            self.hidden_dim = hidden_dim
            self.n_classes = n_classes

            self.fc1 = nn.Linear(2, self.hidden_dim)
            self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

        def forward(self, x):
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            x = F.log_softmax(x, -1)
            return x

    loss_fn = nn.NLLLoss()

    # Only func_model is used from here. The per-model weights come from
    # init_fn below, so this ``weights`` is never read.
    weights, func_model, _ = make_functional(MLPClassifier().to(device))

    def train_step_fn(use_transform, weights, batch, targets, lr=0.2):
        """One SGD step with learning rate ``lr``; returns ``(loss, *new_weights)``.

        With ``use_transform``, gradients come from ``grad_and_value``;
        otherwise from ``torch.autograd.grad``.
        """
        def compute_loss(weights, batch, targets):
            output = func_model(weights, (batch,))
            loss = loss_fn(output, targets)
            return loss

        if use_transform:
            grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
        else:
            loss = compute_loss(weights, batch, targets)
            grad_weights = torch.autograd.grad(loss, weights)

        new_weights = []
        with torch.no_grad():
            for grad_weight, weight in zip(grad_weights, weights):
                new_weights.append(weight - grad_weight * lr)
        # NB: returns a flat tuple of tensors (not a nested list) because
        # torch.vmap outputs must be Tensors.
        return (loss, *new_weights)

    def unpack(train_result):
        # Split train_step_fn's flat output back into (loss, weights).
        return train_result[0], train_result[1:]

    def init_fn(num_models):
        """Stack the parameters of ``num_models`` fresh models along a new leading dim."""
        models = tuple(MLPClassifier().to(device) for _ in range(num_models))
        weights = tuple(make_functional(model)[0] for model in models)
        weights = tuple(zip(*weights))
        weights = tuple(torch.stack(shards).detach() for shards in weights)
        return weights

    def slice_weights(batched_weights, index):
        """Model ``index``'s weights as fresh leaf tensors that require grad."""
        return tuple(weight[index].detach().requires_grad_() for weight in batched_weights)

    batched_weights = init_fn(num_models=2)
    # Map over dim 0 of the weights; points and labels are shared by all models.
    parallel_train_step_fn = vmap(partial(train_step_fn, True), in_dims=(0, None, None))

    result_loss, result_weights = unpack(parallel_train_step_fn(batched_weights, points, labels))

    loss0, weights0 = unpack(train_step_fn(False, slice_weights(batched_weights, 0), points, labels))
    loss1, weights1 = unpack(train_step_fn(False, slice_weights(batched_weights, 1), points, labels))
    expected_loss = torch.stack([loss0, loss1])
    expected_weights = tuple(torch.stack([w0, w1]) for w0, w1 in zip(weights0, weights1))

    self.assertEqual(result_loss, expected_loss)
    self.assertEqual(result_weights, expected_weights)
def create_hamiltonian_wo_k(self, positions, species):
    """
    Build the k-independent Hamiltonian and overlap matrices over all periodic images.

    Args:
        positions: particle position matrix, 2D. The ``.view(..., 3)`` calls
            below require 3 spatial components per particle.
        species: 1D array of species indices, one per particle
    Also reads:
        self.spd: selects the 9-orbital spd basis (truthy) or the 4-orbital sp basis
        self.cutoff: cutoff passed to bondmatrix_masking
        self.shifts: 2D shifts matrix from compute_shifts (one row per periodic image)

    Returns:
        (hamiltonian, overlap_matrix), each reshaped to
        (N_particles * n_orbitals, N_particles * n_orbitals, N_images)
    """
    if self.spd:
        n_orbitals = 9
        get_hop_int = get_hop_int_spd
    else:
        n_orbitals = 4
        get_hop_int = get_hop_int_sp
    # 4 orbitals for sp, 9 for spd
    create_bondmatrix_mask = bondmatrix_masking(self.cutoff)
    # Positions of every particle in every periodic image; must hold
    # N * N_images * 3 elements for the view below.
    shifted_positions = vmap(shift_fn, (0, None))(positions, self.shifts)
    # (N, N, N_images, 3); presumably [i, j, k] = position of particle i in
    # image k minus position of particle j.
    shifted_pair_distance_vectors = shifted_positions.view(positions.shape[0], 1, self.shifts.shape[0], 3) - \
        positions.view(1, positions.shape[0], 1, 3)
    # Expand species to match the shifted coordinates: (N, N_images), [i, k] = species[i].
    # NOTE(review): the original author flagged the device placement of these tensors as unverified.
    shifted_species = torch.repeat_interleave(torch.unsqueeze(species, axis=0), self.shifts.shape[0], dim=0).T
    # (N, N, N_images) cartesian-product species labels:
    #   species_b[i, j, k] = species[i]  (axis 0: the shifted particle)
    #   species_a[i, j, k] = species[j]  (axis 1: the unshifted particle)
    species_b = torch.repeat_interleave(torch.unsqueeze(shifted_species, axis=1), species.shape[0], dim=1)
    species_a = torch.repeat_interleave(torch.unsqueeze(shifted_species, axis=1), species.shape[0], dim=1).permute(1, 0, 2)
    dir_cos = get_dir_cos(
        shifted_pair_distance_vectors
    )  # (particle number, particle number, N_images, 3)
    pair_distances = torch.linalg.norm(
        shifted_pair_distance_vectors, axis=-1)  # (particle number, particle number, N_images): norm drops the last axis
    bondmatrix = create_bondmatrix_mask(pair_distances)
    # Hopping parameters for both orderings (a, b) and (b, a) of each pair.
    param_vec_1 = self.get_params(pair_distances, species_a, species_b, 'V')
    param_vec_2 = self.get_params(pair_distances, species_b, species_a, 'V')
    # Zero the parameters of pairs masked out by the bond matrix.
    param_vec_1 = param_vec_1 * torch.unsqueeze(bondmatrix, axis=-1)
    param_vec_2 = param_vec_2 * torch.unsqueeze(bondmatrix, axis=-1)
    hamiltonian = get_hop_int(
        torch.cat([param_vec_1, dir_cos, param_vec_2], axis=-1)).permute(2, 3, 4, 0, 1)
    # NOTE(review): get_params_diag's third parameter is species_b, but 'diag'
    # is passed here; it is unused there, but confirm this is intended.
    param_diag = self.get_params_diag(pair_distances, species_a, 'diag')
    hamiltonian = hamiltonian + param_diag
    # The overlap matrix uses the same construction with the 'S' parameters
    # and gets no on-site diagonal term here.
    overlap_vec_1 = self.get_params(pair_distances, species_a, species_b, 'S')
    overlap_vec_2 = self.get_params(pair_distances, species_b, species_a, 'S')
    overlap_vec_1 = overlap_vec_1 * torch.unsqueeze(bondmatrix, axis=-1)
    overlap_vec_2 = overlap_vec_2 * torch.unsqueeze(bondmatrix, axis=-1)
    overlap_matrix = get_hop_int(
        torch.cat([overlap_vec_1, dir_cos, overlap_vec_2], axis=-1)).permute(2, 3, 4, 0, 1)
    # Interleave particle and orbital axes, then flatten. Presumably the
    # tensors are (N, N, N_images, n_orb, n_orb) here, so the permute gives
    # (N, n_orb, N, n_orb, N_images) before the reshape.
    hamiltonian = torch.permute(hamiltonian, (0, 3, 1, 4, 2)) \
        .reshape(species.shape[0] * n_orbitals, species.shape[0] * n_orbitals, self.shifts.shape[0])
    overlap_matrix = torch.permute(overlap_matrix, (0, 3, 1, 4, 2)) \
        .reshape(species.shape[0] * n_orbitals, species.shape[0] * n_orbitals, self.shifts.shape[0])
    return hamiltonian, overlap_matrix
def test_vmap_grad(self, device): x = torch.randn(3, device=device) y = vmap(grad(torch.sin))(x) self.assertEqual(y, x.cos())
def test_vmap_on_jacrev_simple(self, device): x = torch.randn(2, 3, device=device) y = vmap(jacrev(torch.sin))(x) expected = torch.stack([torch.diagflat(x[i].cos()) for i in range(2)]) assert torch.allclose(y, expected)
def train(args, model, train_loader, optimizer, epoch, device):
    """Run one epoch of training with per-sample gradients computed via vmap + grad.

    Per-sample gradients are stored on each parameter as ``grad_sample`` and
    passed to ``clip_and_accumulate_and_add_noise`` before every optimizer step.

    Returns:
        The wall-clock duration of the epoch as a ``datetime.timedelta``.
    """
    start_time = datetime.now()

    criterion = nn.CrossEntropyLoss()

    losses = []
    top1_acc = []

    for i, (images, target) in enumerate(tqdm(train_loader)):

        images = images.to(device)
        target = target.to(device)

        # Step 1: compute per-sample-grads

        # In order to use functional vmap+grad, we need to be able to
        # pass the weights to a model. This is rebuilt every iteration, so the
        # weights always reflect the latest optimizer step.
        func_model, weights = make_functional(model)

        # To use vmap+grad to compute per-sample-grads, the forward pass
        # must be re-formulated on a single example.
        # We use the `grad` operator to compute forward+backward on a single example,
        # and finally `vmap` to do forward+backward on multiple examples.
        def compute_loss_and_output(weights, image, target):
            images = image.unsqueeze(0)
            targets = target.unsqueeze(0)
            output = func_model(weights, images)
            loss = criterion(output, targets)
            return loss, output.squeeze(0)

        # `grad(f)` is a functional API that returns a function `f'` that
        # computes gradients by running both the forward and backward pass.
        # We also want some intermediate values from the computation
        # (the loss and the output).
        #
        # To get the loss, we use the `grad_and_value` API, which returns both
        # the gradient of the loss w.r.t. the weights and the loss itself.
        #
        # To get the output, we use the `has_aux=True` flag.
        # `has_aux=True` assumes that `f` returns a tuple of two values,
        # where the first is differentiated and the second ("auxiliary value")
        # is not. `f'` then returns `(grads, (loss, aux))`, as unpacked below.
        grads_loss_output = grad_and_value(compute_loss_and_output, has_aux=True)
        # (None, 0, 0): share the weights, map over dim 0 of images and target.
        sample_grads, (sample_loss, output) = \
            vmap(grads_loss_output, (None, 0, 0))(weights, images, target)
        loss = sample_loss.mean()

        for grad_sample, weight in zip(sample_grads, model.parameters()):
            weight.grad_sample = grad_sample.detach()

        # Step 2: Clip the per-sample-grads, sum them to form grads, and add noise
        clip_and_accumulate_and_add_noise(model, args.max_per_sample_grad_norm, args.sigma)

        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
        labels = target.detach().cpu().numpy()
        losses.append(loss.item())

        # measure accuracy and record loss
        acc1 = accuracy(preds, labels)

        top1_acc.append(acc1)

        # Take an optimizer step on every mini-batch, then clear the gradients
        # (presumably populated by clip_and_accumulate_and_add_noise) before
        # the next batch.
        optimizer.step()
        optimizer.zero_grad()

        if i % args.print_freq == 0:
            print(f"\tTrain Epoch: {epoch} \t"
                  f"Loss: {np.mean(losses):.6f} "
                  f"Acc@1: {np.mean(top1_acc):.6f} ")
    train_duration = datetime.now() - start_time
    return train_duration
def test_resnet18_per_sample_grads(self, device):
    """Per-sample grads of a ResNet-18 (BatchNorm swapped for GroupNorm) from
    vmap(grad) must match a per-sample torch.autograd.grad loop."""
    # Straight out of opacus
    def _replace_child(
        root: nn.Module, child_name: str, converter: Callable[[nn.Module], nn.Module]
    ) -> None:
        # find the immediate parent
        parent = root
        nameList = child_name.split(".")
        for name in nameList[:-1]:
            parent = parent._modules[name]
        # replace the child with converter(child)
        parent._modules[nameList[-1]] = converter(parent._modules[nameList[-1]])

    def replace_all_modules(
        root: nn.Module,
        target_class: Type[nn.Module],
        converter: Callable[[nn.Module], nn.Module],
    ) -> nn.Module:
        # base case
        if isinstance(root, target_class):
            return converter(root)

        # Only existing keys of _modules are reassigned, so iterating
        # named_modules() while replacing does not resize any dict.
        for name, obj in root.named_modules():
            if isinstance(obj, target_class):
                _replace_child(root, name, converter)
        return root

    def _batchnorm_to_groupnorm(module: nn.modules.batchnorm._BatchNorm) -> nn.Module:
        return nn.GroupNorm(min(32, module.num_features), module.num_features, affine=True)

    def convert_batchnorm_modules(
        model: nn.Module,
        converter: Callable[
            [nn.modules.batchnorm._BatchNorm], nn.Module
        ] = _batchnorm_to_groupnorm,
    ) -> nn.Module:
        return replace_all_modules(model, nn.modules.batchnorm._BatchNorm, converter)

    import torchvision.models as models
    # Presumably BatchNorm is swapped out because its batch statistics couple
    # samples, which would make "per-sample" gradients ill-defined.
    model = convert_batchnorm_modules(models.resnet18(num_classes=10)).to(device)
    criterion = nn.CrossEntropyLoss()

    weights, func_model, descriptors = make_functional(model)

    def compute_loss(weights, image, target):
        # Re-add the batch dim so the model sees a batch of one.
        images = image.unsqueeze(0)
        targets = target.unsqueeze(0)
        output = func_model(weights, (images,))
        loss = criterion(output, targets)
        return loss

    batch_size = 3
    images = torch.randn(batch_size, 3, 32, 32, device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)

    result_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, images, targets)

    expected_grads = [
        torch.autograd.grad(compute_loss(weights, images[i], targets[i]), weights)
        for i in range(batch_size)
    ]
    # Regroup per-sample tuples into one stacked tensor per parameter.
    expected_grads = [torch.stack(shards) for shards in zip(*expected_grads)]

    self.assertEqual(result_grads, expected_grads)
# stateless version of the model (fmodel) and stacked parameters and buffers.
from functorch import combine_state_for_ensemble
fmodel, params, buffers = combine_state_for_ensemble(models)
# requires_grad_() is called only for its in-place side effect, so use a plain
# loop instead of building and discarding a list.
for p in params:
    p.requires_grad_()

# Option 1: get predictions using a different minibatch for each model.
# By default, vmap maps a function across the first dimension of all inputs to the
# passed-in function. After `combine_state_for_ensemble`, each of ``params`` and
# ``buffers`` has an additional dimension of size ``num_models`` at the front,
# and ``minibatches`` has a dimension of size ``num_models``.
print([p.size(0) for p in params])
assert minibatches.shape == (num_models, 64, 1, 28, 28)
from functorch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
assert torch.allclose(predictions1_vmap, torch.stack(predictions1), atol=1e-6, rtol=1e-6)

# Option 2: get predictions using the same minibatch of data.
# vmap has an in_dims arg that specifies which dimensions to map over.
# Passing ``None`` tells vmap to apply the same minibatch to all of the
# models.
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-6, rtol=1e-6)
def foo(x):
    """Apply ``sin`` to each slice of ``x`` along dim 0 via vmap and return the total sum."""
    return vmap(torch.sin)(x).sum()
def test_maml_omniglot(self, device):
    """MAML with a conv net: meta-gradients computed with vmap + grad must match
    a per-task loop using torch.autograd.grad(create_graph=True)."""
    # TODO: there appear to be precision issues for float32
    dtype = torch.double

    # TODO: The prototype doesn't support in-place relu (and some other
    # in-place operations. That can be fixed.)
    inplace_relu = False
    n_way = 5
    n_inner_iter = 2
    num_tasks = 2

    class Flatten(nn.Module):
        def forward(self, input):
            return input.view(input.size(0), -1)

    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=inplace_relu),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, n_way)).to(device).to(dtype)

    params, buffers, fnet, _, _, = make_functional_with_buffers(net)
    # From here on, ``net`` is the (params, buffers, fnet) triple expected by loss_for_task.
    net = (params, buffers, fnet)

    def loss_for_task(net, n_inner_iter, use_transform, x_spt, y_spt, x_qry, y_qry):
        """Adapt on the support set for ``n_inner_iter`` SGD steps (lr 0.1), then
        return ``(query_loss, query_accuracy)`` of the adapted parameters.

        ``use_transform`` selects functorch ``grad`` (True) or
        ``torch.autograd.grad`` with ``create_graph=True`` (False) for the inner steps.
        """
        params, buffers, fnet = net
        querysz = x_qry.size(0)

        def compute_loss(new_params, buffers, x, y):
            logits = fnet(new_params, buffers, (x,))
            loss = F.cross_entropy(logits, y)
            return loss

        new_params = params
        for _ in range(n_inner_iter):
            if use_transform:
                grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt)
            else:
                res = compute_loss(new_params, buffers, x_spt, y_spt)
                # create_graph=True so the meta-gradient flows through the inner steps.
                grads = torch.autograd.grad(res, new_params, create_graph=True)
            new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)]

        qry_logits = fnet(new_params, buffers, (x_qry,))
        qry_loss = F.cross_entropy(qry_logits, y_qry)
        qry_acc = (qry_logits.argmax(
            dim=1) == y_qry).sum() / querysz

        return qry_loss, qry_acc

    # Get some sample inputs...
    x_spt = torch.randn(num_tasks, 25, 1, 28, 28, dtype=dtype, device=device)
    y_spt = torch.randint(0, 5, (num_tasks, 25), device=device)
    x_qry = torch.randn(num_tasks, 75, 1, 28, 28, dtype=dtype, device=device)
    y_qry = torch.randint(0, 5, (num_tasks, 75), device=device)

    # compute with vmap + grad
    compute_loss = partial(loss_for_task, net, n_inner_iter, True)
    qry_losses, _ = vmap(compute_loss)(x_spt, y_spt, x_qry, y_qry)
    result_grads = torch.autograd.grad(qry_losses.sum(), params)

    # compute without vmap + grad
    compute_loss = partial(loss_for_task, net, n_inner_iter, False)
    losses = [compute_loss(x_spt[i], y_spt[i], x_qry[i], y_qry[i])[0]
              for i in range(num_tasks)]
    expected_grads = torch.autograd.grad(sum(losses), params)

    self.assertEqual(result_grads, expected_grads)