def dp_sgd_backward(dis, real_imgs, fake_imgs, device, clip_norm, noise_factor, dis_opt):
    """Compute DP-SGD gradients for the discriminator.

    Since only one part of the loss depends on the (real) data, its gradients are
    computed separately, clipped per sample, and perturbed with noise. To keep
    gradient magnitudes comparable, the gradients of the other loss are clipped
    to the same per-sample norm.
    """
    # real data loss first:
    params = list(dis.parameters())
    loss_real = -pt.mean(dis(real_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_real.backward(retain_graph=True)

    squared_param_norms_real = [p.batch_l2 for p in params]  # first we get all the squared parameter norms...
    global_norms_real = pt.sqrt(pt.sum(pt.stack(squared_param_norms_real), dim=0))  # ...then compute the global norms...
    global_clips_real = pt.clamp_max(clip_norm / global_norms_real, 1.)  # ...and finally get a vector of clipping factors

    perturbed_grads = []
    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(global_clips_real, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads, dim=0)  # after clipping we sum over the batch

        noise_sdev = noise_factor * 2 * clip_norm  # gaussian noise standard dev is computed (sensitivity is 2*clip)...
        perturbed_grad = clipped_grad + pt.randn_like(clipped_grad, device=device) * noise_sdev  # ...and applied
        perturbed_grads.append(perturbed_grad)  # store perturbed grads

    dis_opt.zero_grad()

    # now add fake data loss gradients:
    loss_fake = pt.mean(dis(fake_imgs))
    with backpack(BatchGrad(), BatchL2Grad()):
        loss_fake.backward()

    squared_param_norms_fake = [p.batch_l2 for p in params]  # first we get all the squared parameter norms...
    global_norms_fake = pt.sqrt(pt.sum(pt.stack(squared_param_norms_fake), dim=0))  # ...then compute the global norms...
    global_clips_fake = pt.clamp_max(clip_norm / global_norms_fake, 1.)  # ...and finally get a vector of clipping factors

    for idx, param in enumerate(params):
        clipped_sample_grads = param.grad_batch * expand_vector(global_clips_fake, param.grad_batch)
        clipped_grad = pt.sum(clipped_sample_grads, dim=0)  # after clipping we sum over the batch
        param.grad = clipped_grad + perturbed_grads[idx]

    ld = loss_real.item() + loss_fake.item()
    return global_norms_real, global_clips_real, global_norms_fake, global_clips_fake, ld
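# The helper ``expand_vector`` used above is not defined in this snippet. Judging
# from the call sites, it broadcasts a per-sample vector of clipping factors over
# the trailing dimensions of a per-sample gradient tensor. A minimal sketch under
# that assumption (name and signature taken from the calls, body is a guess):
def expand_vector(vec, target_tensor):
    """Reshape a length-N vector so it broadcasts against an (N, ...) tensor."""
    shape = (-1,) + (1,) * (target_tensor.dim() - 1)
    return vec.view(shape)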
def test_extension_hook_param_before_savefield_exists(problem):
    """Extension hooks iterating over parameters may get called before BackPACK.

    In that case, the BackPACK quantities have not been computed yet, so derived
    quantities cannot be calculated. Sequential containers work fine; custom
    containers crash.

    Args:
        problem: problem consisting of model, loss, and problem_string

    Raises:
        NotImplementedError: if problem_string is unknown
    """
    _, loss, problem_string = problem

    params_without_grad_batch = []

    def check_grad_batch(module):
        """Check whether the module's parameters have a grad_batch attribute.

        Args:
            module: the module to check

        Raises:
            AssertionError: if a parameter does not have a grad_batch attribute.
        """
        for p in module.parameters():
            if not hasattr(p, "grad_batch"):
                params_without_grad_batch.append(id(p))
                raise AssertionError(f"Param {id(p)} has no 'grad_batch' attribute")

    if problem_string == NESTED_SEQUENTIAL:
        with backpack(BatchGrad(), extension_hook=check_grad_batch, debug=True):
            loss.backward()
        assert len(params_without_grad_batch) == 0
    elif problem_string == CUSTOM_CONTAINER:
        with raises(AssertionError):
            with backpack(BatchGrad(), extension_hook=check_grad_batch, debug=True):
                loss.backward()
        assert len(params_without_grad_batch) > 0
    else:
        raise NotImplementedError(f"unknown problem_string={problem_string}")
def test_for_loop_replace() -> None:
    """Application of retain_graph: replace an outer for-loop.

    This test is based on issue #220 opened by Romain3Ch216.
    It computes per-component individual gradients of a tensor-valued output
    with a for loop over components, rather than over samples and components.
    """
    manual_seed(0)
    B = 5
    M = 3
    h = 2

    x = randn(B, h)
    fc = extend(Linear(h, M))
    A = fc(x)

    grad_autograd = zeros(B, M, *fc.weight.shape)
    for b in range(B):
        for m in range(M):
            with backpack(retain_graph=True):
                grads = autograd.grad(A[b, m], fc.weight, retain_graph=True)
            grad_autograd[b, m] = grads[0]

    grad_backpack = zeros(B, M, *fc.weight.shape)
    for i in range(M):
        with backpack(BatchGrad(), retain_graph=True):
            A[:, i].backward(ones_like(A[:, i]), retain_graph=True)
        grad_backpack[:, i] = fc.weight.grad_batch

    check_sizes_and_values(grad_backpack, grad_autograd)
def select(self, context):
    tensor = torch.from_numpy(context).float().cuda()
    mu = self.func(tensor)

    # calculate gradient
    sum_mu = torch.sum(mu)
    with backpack(BatchGrad()):
        sum_mu.backward()
    g_list = torch.cat(
        [p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()],
        dim=1,
    )

    # calculate CB
    sigma = torch.sqrt(
        torch.diag(
            torch.matmul(torch.matmul(g_list, self.Uinv),
                         torch.transpose(g_list, 0, 1))))

    # calculate UCB and select the arm
    score = torch.normal(mu.view(-1), self.nu * sigma.view(-1))
    arm = torch.argmax(score)

    # update self.Uinv
    g = g_list[arm]
    self.Uinv = self.Uinv - (torch.matmul(
        torch.matmul(torch.matmul(self.Uinv, g.view(-1, 1)), g.view(1, -1)),
        self.Uinv)) / (1 + torch.matmul(
            torch.matmul(g.view(1, -1), self.Uinv), g.view(-1, 1)))
    return arm, [], [], []
def update(self, rollouts, explore_policy, exploration):
    obs_shape = rollouts.obs.size()[2:]
    action_shape = rollouts.actions.size()[-1]
    num_steps, num_processes, _ = rollouts.rewards.size()

    qs = self.model(rollouts.obs.view(-1, *obs_shape)).view(
        num_steps + 1, num_processes, -1)
    values = qs[:-1].gather(-1, rollouts.actions).view(num_steps, num_processes, 1)
    probs, _ = explore_policy(qs[-1].detach(), exploration)
    next_values = (probs * qs[-1]).sum(-1).unsqueeze(-1).unsqueeze(0)
    advantages = rollouts.returns[:-1] - values

    with backpack(BatchGrad()):
        self.optimizer.zero_grad()
        torch.cat([values, next_values], dim=0).sum().backward()

    # store the td errors and masks to use inside tdprop
    self.optimizer.temp_store_td_errors(advantages.detach())
    self.optimizer.temp_store_masks(rollouts.cumul_masks)
    # extract grads: grad = -2 * td * grad_v
    self.optimizer.extract_grads_from_batch()

    total_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
    self.optimizer.clip_coef = self.max_grad_norm / (total_norm + 1e-6)
    self.optimizer.step()

    return advantages.pow(2).mean().item()
def self_interferences(loss, params, t=False):
    with backpack(BatchGrad()):
        sumloss(loss).backward()
    grads = torch.cat([i.grad_batch.reshape((loss.shape[0], -1)) for i in params], 1)
    if t:
        return (grads ** 2).sum(1)
    return (grads ** 2).sum(1).cpu().data.numpy()
def cross_interferences_diag(lossA, lossB, params, t=False):
    """Compute the diagonal of the cross-interference between two losses.

    lossA and lossB must be callables: BackPACK appears to require that only one
    computation graph is associated with the parameters at a time (or it only
    inspects the most recent one), so each loss is built right before its
    backward pass.
    """
    with backpack(BatchGrad()):
        sumloss(lossA()).backward()
    gradsA = torch.cat(
        [i.grad_batch.reshape((i.grad_batch.shape[0], -1)) for i in params], 1)
    with backpack(BatchGrad()):
        sumloss(lossB()).backward()
    gradsB = torch.cat(
        [i.grad_batch.reshape((i.grad_batch.shape[0], -1)) for i in params], 1)
    x = (gradsA * gradsB).sum(1)
    gc.collect()
    if t:
        return x
    return x.cpu().data.numpy()
def interferences(loss, params):
    with backpack(BatchGrad()):
        sumloss(loss).backward()
    grads = torch.cat([i.grad_batch.reshape((loss.shape[0], -1)) for i in params], 1)
    all_dots = []
    for i in range(len(grads) - 1):
        all_dots.append((grads[i][None, :] @ grads[i + 1:].t())[0])
    return torch.cat(all_dots, 0).cpu().data.numpy()
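# The interference helpers above rely on a ``sumloss`` function that is not shown
# in these snippets. It presumably reduces a per-sample loss vector to a scalar so
# that ``.backward()`` can be called, while BatchGrad still yields per-sample
# gradients. A minimal sketch under that assumption:
def sumloss(per_sample_loss):
    """Sum a vector of per-sample losses into a scalar for backward()."""
    return per_sample_loss.sum()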
def get_anchor_gradients(self, net, loss_func):
    """Get the n x p matrix of gradients based on public data."""
    public_inputs, public_targets = self.public_inputs, self.public_targets
    outputs = net(public_inputs)
    loss = loss_func(outputs, public_targets)
    with backpack(BatchGrad()):
        loss.backward()
    cur_batch_grad_list = []
    for p in net.parameters():
        cur_batch_grad_list.append(p.grad_batch)
        del p.grad_batch
    return flatten_tensor(cur_batch_grad_list)  # n x p
def _select(self, context):
    tensor = torch.from_numpy(context).float().to(self.device)
    mu = self.model(tensor)
    sum_mu = torch.sum(mu)
    with backpack(BatchGrad()):
        sum_mu.backward()
    g_list = torch.cat(
        [p.grad_batch.flatten(start_dim=1).detach() for p in self.model.parameters()],
        dim=1)
    sigma = torch.sqrt(torch.sum(g_list * g_list / self.A, dim=1))
    sample_r = torch.normal(mu.view(-1), self.alpha * sigma.view(-1))
    arm = torch.argmax(sample_r)
    self.A += g_list[arm] * g_list[arm]
    return arm, g_list[arm].norm().item(), 0, 0
def get_anchor_gradients(self, net, loss_func):
    public_inputs, public_targets = self.public_inputs, self.public_targets
    outputs = net(public_inputs)
    loss = loss_func(outputs, public_targets)
    with backpack(BatchGrad()):
        loss.backward()
    cur_batch_grad_list = []
    for p in net.parameters():
        cur_batch_grad_list.append(p.grad_batch.reshape(p.grad_batch.shape[0], -1))
        del p.grad_batch
    return flatten_tensor(cur_batch_grad_list)
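# ``flatten_tensor`` is not defined in these snippets. Judging from its use on a
# list of per-parameter ``grad_batch`` tensors, it presumably concatenates them
# into a single n x p matrix (n samples, p flattened parameters). A minimal
# sketch under that assumption:
def flatten_tensor(tensor_list):
    """Concatenate a list of (n, ...) tensors into one (n, p) matrix."""
    return torch.cat([t.reshape(t.shape[0], -1) for t in tensor_list], dim=1)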
def select(self, context):
    tensor = torch.from_numpy(context).float().cuda()
    mu = self.func(tensor)
    sum_mu = torch.sum(mu)
    with backpack(BatchGrad()):
        sum_mu.backward()
    g_list = torch.cat(
        [p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()],
        dim=1)
    sigma = torch.sqrt(torch.sum(self.lamdba * self.nu * g_list * g_list / self.U, dim=1))
    if self.style == 'ts':
        sample_r = torch.normal(mu.view(-1), sigma.view(-1))
    elif self.style == 'ucb':
        sample_r = mu.view(-1) + sigma.view(-1)
    arm = torch.argmax(sample_r)
    self.U += g_list[arm] * g_list[arm]
    return arm, g_list[arm].norm().item(), 0, 0
def extensions(self, global_step):
    """Return list of BackPACK extensions required for the computation.

    Args:
        global_step (int): The current iteration number.

    Returns:
        list: (Potentially empty) list with required BackPACK quantities.
    """
    ext = []
    if self.should_compute(global_step):
        ext.append(BatchGrad())
    return ext
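# The training loop that consumes this list is not shown here. A hedged sketch of
# how the returned extensions might typically be unpacked into a BackPACK context
# (``quantity``, ``global_step``, and ``loss`` are assumed to exist):
ext = quantity.extensions(global_step)
with backpack(*ext):
    loss.backward()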
def select(self, context):
    tensor = torch.from_numpy(context).float().cuda()
    mu = self.func(tensor)
    sum_mu = torch.sum(mu)
    with backpack(BatchGrad()):
        sum_mu.backward()
    g_list = torch.cat([
        p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()
    ], dim=1)
    sigma = torch.sqrt(torch.sum(g_list * g_list / self.U, dim=1))
    sample_r = mu.view(-1) + self.nu * sigma.view(-1)
    arm = torch.argmax(sample_r)
    self.U += g_list[arm] * g_list[arm]
    # print("selected arm: {} {}".format(mu.view(-1)[arm], sigma[arm]))
    return arm, g_list[arm].norm().item(), 0, 0
def _train_batch_(self):
    (data, target) = next(self.train_loader_it)
    data, target = data.to(self.device), target.to(self.device)
    self.current_batch_size = data.size()[0]
    output = self.model(data)
    loss = self.loss_function(output, target)
    with backpack(BatchGrad()):
        loss.mean().backward()
    loss_value = loss.mean()
    self.loss_batch = loss
    self.current_training_loss = torch.unsqueeze(loss_value.detach(), dim=0)
    self.train_batch_index += 1
    self._current_validation_loss.calculated = False
def decide(self, pool_articles, userID, k=1):
    self.a = len(pool_articles)
    user_vec = torch.cat(self.a * [self.user_feature[userID - 1].view(1, -1)])
    article_vec = torch.cat([
        torch.from_numpy(x.contextFeatureVector[:self.dimension]).view(1, -1).to(torch.float32)
        for x in pool_articles])
    score = self.learner(user_vec, article_vec.cuda()).view(-1)
    sum_score = torch.sum(score)
    with backpack(BatchGrad()):
        sum_score.backward()
    grad = torch.cat([p.grad_batch.view(self.a, -1) for p in self.learner.parameters()], dim=1)
    sigma = torch.sqrt(torch.sum(grad * grad / self.U, dim=1))
    self.reg = self.nu * torch.mean(sigma).item()
    arm = torch.argmax(score + self.nu * sigma).item()
    self.g = grad[arm]
    return [pool_articles[arm]]
def test_retain_graph():
    """Tests whether retain_graph works as expected.

    Does several forward and backward passes.
    In between, it is tested whether BackPACK quantities are present or not.
    """
    manual_seed(0)
    model = extend(Sequential(Linear(4, 6), Linear(6, 5)))
    loss_fn = extend(CrossEntropyLoss())

    # after a forward pass graph is not clear
    inputs = rand(8, 4)
    labels = randint(5, (8,))
    loss = loss_fn(model(inputs), labels)
    with raises(AssertionError):
        _check_no_io(model)

    # after a normal backward pass graph should be clear
    loss.backward()
    _check_no_io(model)

    # after a backward pass with retain_graph=True graph is not clear
    loss = loss_fn(model(inputs), labels)
    with backpack(retain_graph=True):
        loss.backward(retain_graph=True)
    with raises(AssertionError):
        _check_no_io(model)

    # doing several backward passes with retain_graph=True
    for _ in range(3):
        with backpack(retain_graph=True):
            loss.backward(retain_graph=True)
        with raises(AssertionError):
            _check_no_io(model)

    # finally doing a normal backward pass that verifies graph is clear again
    with backpack(BatchGrad()):
        loss.backward()
    _check_no_io(model)
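# ``_check_no_io`` is a test helper not shown in this snippet. BackPACK stores the
# module inputs/outputs it needs as ``input0``/``output`` attributes during the
# forward pass, so a hedged sketch of such a check (assuming that storage scheme,
# not necessarily the test suite's actual helper) might look like:
def _check_no_io(module):
    """Assert that no submodule still holds BackPACK input/output buffers."""
    for m in module.modules():
        assert not hasattr(m, "input0") and not hasattr(m, "output"), (
            f"Module {m} still stores BackPACK I/O")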
def main(args):
    print(args)
    assert args.dpsgd
    torch.backends.cudnn.benchmark = True

    train_data, train_labels = get_data(args)
    model = model_dict[args.experiment](vocab_size=args.max_features).cuda()
    model = extend(model)
    optimizer = DP_SGD(model.parameters(), lr=args.learning_rate,
                       **dpsgd_kwargs[args.experiment])
    loss_function = nn.CrossEntropyLoss() if args.experiment != 'logreg' else nn.BCELoss()

    timings = []
    for epoch in range(1, args.epochs + 1):
        start = time.perf_counter()
        dataloader = data.dataloader(train_data, train_labels, args.batch_size)
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
            model.zero_grad()
            outputs = model(x)
            loss = loss_function(outputs, y)
            with backpack(BatchGrad(), BatchL2Grad()):
                loss.backward()
            optimizer.step()
        torch.cuda.synchronize()
        duration = time.perf_counter() - start
        print("Time Taken for Epoch: ", duration)
        timings.append(duration)

    if not args.no_save:
        utils.save_runtimes(__file__.split('.')[0], args, timings)
    else:
        print('Not saving!')
    print('Done!')
def backward_and_step(self, vt, vtp, delta, gamma_k):
    """Perform the backward pass and step for n-step TD(0); does not work for TD(lambda).

    Requires:
        - vt: V(s_t)
        - vtp: V(s_{t+n})
        - delta: the TD delta, i.e. (v - (r + gamma*vp)) for TD(0),
          (v - (G^5 + gamma^5*vp)) for 5-step TD;
          be careful to provide (v - y) rather than (y - v)
        - gamma_k: the multiplicative factor on vp applied in delta,
          e.g. gamma^5 for 5-step TD; can be 0, e.g. if s_{t+n} is terminal
    """
    self.zero_grad()
    with backpack(BatchGrad()):
        torch.cat([vt, vtp], dim=0).sum().backward()
    mbs = vt.shape[0]
    batch_g = self.batch_p2v([i.grad_batch.data for i in self.parameters], mbs * 2)
    batch_dvt = batch_g[:mbs]
    batch_dvtp = batch_g[mbs:] * gamma_k[:, None]
    # derivative of delta^2
    batch_dL = 2 * delta[:, None] * batch_dvt
    dL = batch_dL.mean(0)
    if self.weight_decay:
        # We ignore weight decay when computing the correction
        # This should not be problematic but further testing required
        dL.add_(self.p2v(self.parameters).detach(), alpha=self.weight_decay)

    self._step += 1
    bias_correction1 = 1 - self.beta1**self._step if self.mom_correct_bias else 1
    bias_correction2 = 1 - self.beta2**self._step

    # \/V(s) - \/V(s')
    gdiff = batch_dvt - batch_dvtp

    if self.do_correction:
        # last corrected momentum
        mu_tm1 = self.mu - self.eta
        # Update eta
        if self.diagonal:
            z = (gdiff * batch_dvt).mean(0)
            #z = (gdiff.mean(0) * batch_dvt.mean(0))
            self.eta.mul_(self.beta).add_(self.alpha * self.beta * (self.zeta * mu_tm1))  # eta
        elif self.block_size > 0:
            pad = torch.zeros((self.nparams_padding), device=self.device)
            batch_pad = torch.zeros((mbs, self.nparams_padding), device=self.device)
            batch_shape = mbs, self.nblocks, self.block_size
            # batch_block_shape = mbs, self.nblocks, self.block_size, self.block_size
            gdiff_padded = torch.cat([gdiff, batch_pad], 1).reshape(batch_shape)
            batch_dvt_padded = torch.cat([batch_dvt, batch_pad], 1).reshape(batch_shape)
            z = torch.einsum('ija,ijb->jab', gdiff_padded, batch_dvt_padded) / mbs
            mu_padded = torch.cat([mu_tm1, pad], 0).reshape((self.nblocks, self.block_size))
            zeta_T_times_mu = (torch.einsum('ikj,ik->ij', self.zeta, mu_padded)
                               .reshape((-1,))[:self.nparams])  # unpad
            self.eta.mul_(self.beta).add_(self.alpha * self.beta * zeta_T_times_mu)  # eta
        else:
            z = torch.einsum('ij,ik->jk', gdiff, batch_dvt) / mbs  # sum of outer product
            #z = torch.einsum('j,k->jk', gdiff.mean(0), batch_dvt.mean(0))
            self.eta.mul_(self.beta).add_(self.alpha * self.beta * (self.zeta.T @ mu_tm1))  # eta
        # Update zeta
        self.zeta.mul_(self.beta).add_(z, alpha=1 - self.dampening)

    # Update momentum mu
    self.mu.mul_(self.beta).add_(dL, alpha=1 - self.dampening)
    # Compute (corrected) momentum \mu_t
    mu_t = self.mu - self.eta if self.do_correction else self.mu

    # TDProp update
    if self.beta2 > 0:
        # diag[(\/V(s) - \/V(s')) ^T \/V(s)]
        diag_H = (2 * gdiff * batch_dvt).pow(2).mean(0)
        # Update TDprop denominator
        self.z_denom.mul_(self.beta2).add_(1 - self.beta2, diag_H)
        # Compute bias corrected TDprop denom
        denom = (self.z_denom.sqrt() / np.sqrt(bias_correction2)).add_(self.epsilon)
        # Update parameters
        for p, g, d in zip(self.parameters, self.v2p(mu_t), self.v2p(denom)):
            p.data.addcdiv_(g, d, value=-(self.alpha / bias_correction1))
    # Normal or corrected momentum update
    else:
        for p, g in zip(self.parameters, self.v2p(mu_t)):
            p.data.add_(g, alpha=-(self.alpha / bias_correction1))
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    t0 = time.time()
    steps = n_training // args.batchsize

    if train_samples == None:
        # using pytorch data loader for CIFAR10
        loader = iter(trainloader)
    else:
        # manually sample minibatches for SVHN
        sample_idxes = np.arange(n_training)
        np.random.shuffle(sample_idxes)

    is_first = 0
    for batch_idx in range(steps):
        if args.dataset == 'svhn':
            current_batch_idxes = sample_idxes[batch_idx * args.batchsize:(batch_idx + 1) * args.batchsize]
            inputs, targets = train_samples[current_batch_idxes], train_labels[current_batch_idxes]
        else:
            inputs, targets = next(loader)
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        if args.private:
            logging = batch_idx % 20 == 0

            ## compute anchor subspace
            optimizer.zero_grad()
            if is_first == 0:
                net.gep.get_anchor_space(net, loss_func=loss_func, logging=logging)
                is_first += 1

            ## collect batch gradients
            batch_grad_list = []
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_func(outputs, targets)
            with backpack(BatchGrad()):
                loss.backward()
            for p in net.parameters():
                batch_grad_list.append(p.grad_batch.reshape(p.grad_batch.shape[0], -1))
                del p.grad_batch

            ## compute gradient embeddings and residual gradients
            clipped_theta, residual_grad, target_grad = net.gep(flatten_tensor(batch_grad_list), logging=logging)

            ## add noise to guarantee differential privacy
            theta_noise = torch.normal(0, noise_multiplier0 * args.clip0 / args.batchsize,
                                       size=clipped_theta.shape, device=clipped_theta.device)
            grad_noise = torch.normal(0, noise_multiplier1 * args.clip1 / args.batchsize,
                                      size=residual_grad.shape, device=residual_grad.device)
            clipped_theta += theta_noise
            residual_grad += grad_noise

            ## update with Biased-GEP or GEP
            if args.rgp:
                noisy_grad = gep.get_approx_grad(clipped_theta) + residual_grad
            else:
                noisy_grad = gep.get_approx_grad(clipped_theta)

            if logging:
                print('target grad norm: %.2f, noisy approximation norm: %.2f'
                      % (target_grad.norm().item(), noisy_grad.norm().item()))

            ## make use of noisy gradients
            offset = 0
            for p in net.parameters():
                shape = p.grad.shape
                numel = p.grad.numel()
                p.grad.data = noisy_grad[offset:offset + numel].view(shape)  #+ 0.1*torch.mean(pub_grad, dim=0).view(shape)
                offset += numel
        else:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_func(outputs, targets)
            loss.backward()
            try:
                for p in net.parameters():
                    del p.grad_batch
            except:
                pass

        optimizer.step()
        step_loss = loss.item()
        if args.private:
            step_loss /= inputs.shape[0]

        train_loss += step_loss
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).float().cpu().sum()

    acc = 100. * float(correct) / float(total)
    t1 = time.time()
    print('Train loss:%.5f' % (train_loss / (batch_idx + 1)), 'time: %d s' % (t1 - t0),
          'train acc:', acc, '\n')
    return (train_loss / batch_idx, acc)
model = extend(MyFirstResNet()).to(DEVICE)

# %%
# Using :py:class:`BatchGrad <backpack.extensions.BatchGrad>` in a
# :py:class:`with backpack(...) <backpack.backpack>` block,
# we can access the individual gradients for each sample.
#
# The loss does not need to be extended in this case either, as it does not
# have model parameters and BackPACK does not need to know about it for
# first order extensions. This also means you can use any custom loss function.

model.zero_grad()
loss = F.cross_entropy(model(x), y, reduction="sum")

with backpack(BatchGrad()):
    loss.backward()

print("{:<20} {:<30} {:<30}".format("Param", "grad", "grad (batch)"))
print("-" * 80)

for name, p in model.named_parameters():
    print(
        "{:<20}: {:<30} {:<30}".format(name, str(p.grad.shape), str(p.grad_batch.shape))
    )

# %%
# To check that everything works, let's compute one individual gradient with
# PyTorch (using a single sample in a forward and backward pass)
# and compare it with the one computed by BackPACK.

sample_to_check = 1
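# %%
# The snippet above is truncated before the actual comparison. A hedged sketch of
# such a check (not necessarily the tutorial's exact code; assumes ``torch`` is
# imported and that the model behaves the same on a single sample, e.g. no
# BatchNorm in training mode): with ``reduction="sum"``, the BackPACK individual
# gradient of sample ``sample_to_check`` should match the ordinary gradient
# computed from that sample alone.

x_i = x[sample_to_check].unsqueeze(0)
y_i = y[sample_to_check].unsqueeze(0)

grad_batch_saved = {name: p.grad_batch.clone() for name, p in model.named_parameters()}

model.zero_grad()
F.cross_entropy(model(x_i), y_i, reduction="sum").backward()

for name, p in model.named_parameters():
    match = torch.allclose(p.grad, grad_batch_saved[name][sample_to_check], atol=1e-6)
    print(f"{name}: individual gradient matches: {match}")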
lossfunc = extend(lossfunc)

# %%
# Individual gradients for a mini-batch subset
# --------------------------------------------
#
# Let's say we only want to compute individual gradients for samples 0, 1,
# 13, and 42. Naively, we could perform the computation for all samples, then
# slice out the samples we care about.

# selected samples
subsampling = [0, 1, 13, 42]

loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

# naive approach: compute for all, slice out relevant
naive = [p.grad_batch[subsampling] for p in model.parameters()]

# %%
# This is not efficient, as individual gradients are computed for all samples,
# most of them being discarded after. We can do better by specifying the active
# samples directly with the ``subsampling`` argument of
# :py:class:`BatchGrad <backpack.extensions.BatchGrad>`.

loss = lossfunc(model(X), y)

# efficient approach: specify active samples in backward pass
with backpack(BatchGrad(subsampling=subsampling)):
print("\tmodule.input0.shape:", module.input0.shape) # gradient w.r.t output print("\tg_out[0].shape: ", g_out[0].shape) # actual computation return (g_out[0] * module.input0).flatten(start_dim=1).sum(axis=1).unsqueeze(-1) # %% # Lastly, we need to register the mapping between layer (``ScaleModule``) and layer # extension (``ScaleModuleBatchGrad``) in an instance of # :py:class:`BatchGrad <backpack.extensions.BatchGrad>`. # register module-computation mapping extension = BatchGrad() extension.set_module_extension(ScaleModule, ScaleModuleBatchGrad()) # %% # That's it. We can now pass ``extension`` to a # :py:class:`with backpack(...) <backpack.backpack>` context and compute individual # gradients with respect to ``ScaleModule``'s ``weight`` parameter. # %% # Test custom module # ------------------ # Here, we verify the custom module extension on a small net with random inputs. # Let's create these. batch_size = 10 batch_axis = 0
            x = self.linear1(x)
            x = self.linear2(x)
            return x

        model = _MyCustomModule()
    else:
        raise NotImplementedError(
            f"problem={problem_string} but no test setting for this.")

    model = extend(model.to(device))
    lossfunc = extend(CrossEntropyLoss(reduction="mean").to(device))
    loss = lossfunc(model(X), y)
    yield model, loss, problem_string


@mark.parametrize("extension", [BatchGrad(), DiagGGNExact()],
                  ids=["BatchGrad", "DiagGGNExact"])
def test_extension_hook_multiple_parameter_visits(
        problem, extension: BackpropExtension):
    """Tests whether each parameter is visited exactly once.

    For those cases where parameters are visited more than once (e.g. custom
    containers), it tests that an error is raised.

    Furthermore, it is tested whether first order extensions run fine in either
    case, and second order extensions raise an error in the case of custom
    containers.

    Args:
        problem: test problem, consisting of model, loss, and problem_string
        extension: first or second order extension to test
def train(epoch): torch.set_printoptions(precision=16) print('\nEpoch: %d' % epoch) net.train() train_loss = 0 correct = 0 total = 0 step_st_time = time.time() epoch_time = 0 print('\nKFAC/KBFGS damping: %f' % damping) print('\nNGD damping: %f' % (damping)) # desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % (tag, lr_scheduler.get_last_lr()[0], 0, 0, correct, total)) writer.add_scalar('train/lr', lr_scheduler.get_last_lr()[0], epoch) prog_bar = tqdm(enumerate(trainloader), total=len(trainloader), desc=desc, leave=True) for batch_idx, (inputs, targets) in prog_bar: if optim_name in ['kfac', 'skfac', 'ekfac', 'sgd', 'adam']: inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) if optim_name in ['kfac', 'skfac', 'ekfac'] and optimizer.steps % optimizer.TCov == 0: # compute true fisher optimizer.acc_stats = True with torch.no_grad(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device) loss_sample = criterion(outputs, sampled_y) loss_sample.backward(retain_graph=True) optimizer.acc_stats = False optimizer.zero_grad() # clear the gradient for computing true-fisher. loss.backward() optimizer.step() elif optim_name in ['kbfgs', 'kbfgsl', 'kbfgsl_2loop', 'kbfgsl_mem_eff']: inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net.forward(inputs) loss = criterion(outputs, targets) loss.backward() # do another forward-backward pass over batch inside step() def closure(): return inputs, targets, criterion, False # is_autoencoder = False optimizer.step(closure) elif optim_name == 'exact_ngd': inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) # update Fisher inverse if batch_idx % args.freq == 0: # compute true fisher with torch.no_grad(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs.cpu().data, dim=1),1).squeeze().to(args.device) # use backpack extension to compute individual gradient in a batch batch_grad = [] with backpack(BatchGrad()): loss_sample = criterion(outputs, sampled_y) loss_sample.backward(retain_graph=True) for name, param in net.named_parameters(): if hasattr(param, "grad_batch"): batch_grad.append(args.batch_size * param.grad_batch.reshape(args.batch_size, -1)) else: raise NotImplementedError J = torch.cat(batch_grad, 1) fisher = torch.matmul(J.t(), J) / args.batch_size inv = torch.linalg.inv(fisher + damping * torch.eye(fisher.size(0)).to(fisher.device)) # clean the gradient to compute the true fisher optimizer.zero_grad() loss.backward() # compute the step direction p = F^-1 @ g grad_list = [] for name, param in net.named_parameters(): grad_list.append(param.grad.data.reshape(-1, 1)) g = torch.cat(grad_list, 0) p = torch.matmul(inv, g) start = 0 for name, param in net.named_parameters(): end = start + param.data.reshape(-1, 1).size(0) param.grad.copy_(p[start:end].reshape(param.grad.data.shape)) start = end optimizer.step() ### new optimizer test elif optim_name in ['kngd'] : inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) if optimizer.steps % optimizer.freq == 0: # compute true fisher optimizer.acc_stats = True with torch.no_grad(): sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) loss_sample = criterion(outputs, 
sampled_y) loss_sample.backward(retain_graph=True) optimizer.acc_stats = False optimizer.zero_grad() # clear the gradient for computing true-fisher. if args.partial_backprop == 'true': idx = (sampled_y == targets) == False loss = criterion(outputs[idx,:], targets[idx]) # print('extra:', idx.sum().item()) loss.backward() optimizer.step() elif optim_name == 'ngd': if batch_idx % args.freq == 0: store_io_(True) inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() # net.set_require_grad(True) outputs = net(inputs) damp = damping loss = criterion(outputs, targets) loss.backward(retain_graph=True) # storing original gradient for later use grad_org = [] # grad_dict = {} for name, param in net.named_parameters(): grad_org.append(param.grad.reshape(1, -1)) # grad_dict[name] = param.grad.clone() grad_org = torch.cat(grad_org, 1) ###### now we have to compute the true fisher with torch.no_grad(): # gg = torch.nn.functional.softmax(outputs, dim=1) sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) if args.trial == 'true': update_list, loss = optimal_JJT_v2(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient, super_opt=args.super_opt) else: update_list, loss = optimal_JJT(outputs, sampled_y, args.batch_size, damping=damp, alpha=0.95, low_rank=args.low_rank, gamma=args.gamma, memory_efficient=args.memory_efficient) # optimizer.zero_grad() # update_list, loss = optimal_JJT_fused(outputs, sampled_y, args.batch_size, damping=damp) optimizer.zero_grad() # last part of SMW formula grad_new = [] for name, param in net.named_parameters(): param.grad.copy_(update_list[name]) grad_new.append(param.grad.reshape(1, -1)) grad_new = torch.cat(grad_new, 1) # grad_new = grad_org store_io_(False) else: inputs, targets = inputs.to(args.device), targets.to(args.device) optimizer.zero_grad() # net.set_require_grad(True) outputs = net(inputs) damp = damping loss = criterion(outputs, targets) loss.backward() # storing original gradient for later use grad_org = [] # grad_dict = {} for name, param in net.named_parameters(): grad_org.append(param.grad.reshape(1, -1)) # grad_dict[name] = param.grad.clone() grad_org = torch.cat(grad_org, 1) ###### now we have to compute the true fisher # with torch.no_grad(): # gg = torch.nn.functional.softmax(outputs, dim=1) # sampled_y = torch.multinomial(torch.nn.functional.softmax(outputs, dim=1),1).squeeze().to(args.device) all_modules = net.modules() for m in net.modules(): if hasattr(m, "NGD_inv"): grad = m.weight.grad if isinstance(m, nn.Linear): I = m.I G = m.G n = I.shape[0] NGD_inv = m.NGD_inv grad_prod = einsum("ni,oi->no", (I, grad)) grad_prod = einsum("no,no->n", (grad_prod, G)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("n,no->no", (v, G)) gv = einsum("no,ni->oi", (gv, I)) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) elif isinstance(m, nn.Conv2d): if hasattr(m, "AX"): if args.low_rank.lower() == 'true': ###### using low rank structure U = m.U S = m.S V = m.V NGD_inv = m.NGD_inv n = NGD_inv.shape[0] grad_reshape = grad.reshape(grad.shape[0], -1) grad_prod = V @ grad_reshape.t().reshape(-1, 1) grad_prod = torch.diag(S) @ grad_prod grad_prod = U @ grad_prod grad_prod = grad_prod.squeeze() v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = U.t() @ v.unsqueeze(1) gv = torch.diag(S) @ gv gv = V.t() @ gv gv = gv.reshape(grad_reshape.shape[1], 
grad_reshape.shape[0]).t() gv = gv.view_as(grad) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) else: AX = m.AX NGD_inv = m.NGD_inv n = AX.shape[0] grad_reshape = grad.reshape(grad.shape[0], -1) grad_prod = einsum("nkm,mk->n", (AX, grad_reshape)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("nkm,n->mk", (AX, v)) gv = gv.view_as(grad) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) elif hasattr(m, "I"): I = m.I if args.memory_efficient == 'true': I = unfold_func(m)(I) G = m.G n = I.shape[0] NGD_inv = m.NGD_inv grad_reshape = grad.reshape(grad.shape[0], -1) x1 = einsum("nkl,mk->nml", (I, grad_reshape)) grad_prod = einsum("nml,nml->n", (x1, G)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("n,nml->nml", (v, G)) gv = einsum("nml,nkl->mk", (gv, I)) gv = gv.view_as(grad) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): if args.batchnorm == 'true': dw = m.dw n = dw.shape[0] NGD_inv = m.NGD_inv grad_prod = einsum("ni,i->n", (dw, grad)) v = matmul(NGD_inv, grad_prod.unsqueeze(1)).squeeze() gv = einsum("n,ni->i", (v, dw)) gv = gv / n update = (grad - gv)/damp m.weight.grad.copy_(update) # last part of SMW formula grad_new = [] for name, param in net.named_parameters(): grad_new.append(param.grad.reshape(1, -1)) grad_new = torch.cat(grad_new, 1) # grad_new = grad_org ##### do kl clip lr = lr_scheduler.get_last_lr()[0] # vg_sum = 0 # vg_sum += (grad_new * grad_org ).sum() # vg_sum = vg_sum * (lr ** 2) # nu = min(1.0, math.sqrt(args.kl_clip / vg_sum)) # for name, param in net.named_parameters(): # param.grad.mul_(nu) # optimizer.step() # manual optimizing: with torch.no_grad(): for name, param in net.named_parameters(): d_p = param.grad.data # print('=== step ===') # apply momentum # if args.momentum != 0: # buf[name].mul_(args.momentum).add_(d_p) # d_p.copy_(buf[name]) # apply weight decay if args.weight_decay != 0: d_p.add_(args.weight_decay, param.data) lr = lr_scheduler.get_last_lr()[0] param.data.add_(-lr, d_p) # print('d_p:', d_p.shape) # print(d_p) train_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() desc = ('[%s][LR=%s] Loss: %.3f | Acc: %.3f%% (%d/%d)' % (tag, lr_scheduler.get_last_lr()[0], train_loss / (batch_idx + 1), 100. * correct / total, correct, total)) prog_bar.set_description(desc, refresh=True) if args.step_info == 'true' and (batch_idx % 50 == 0 or batch_idx == len(prog_bar) - 1): step_saved_time = time.time() - step_st_time epoch_time += step_saved_time test_acc, test_loss = test(epoch) TRAIN_INFO['train_acc'].append(float("{:.4f}".format(100. * correct / total))) TRAIN_INFO['test_acc'].append(float("{:.4f}".format(test_acc))) TRAIN_INFO['train_loss'].append(float("{:.4f}".format(train_loss/(batch_idx + 1)))) TRAIN_INFO['test_loss'].append(float("{:.4f}".format(test_loss))) TRAIN_INFO['total_time'].append(float("{:.4f}".format(step_saved_time))) if args.debug_mem == 'true': TRAIN_INFO['memory'].append(torch.cuda.memory_reserved()) step_st_time = time.time() net.train() writer.add_scalar('train/loss', train_loss/(batch_idx + 1), epoch) writer.add_scalar('train/acc', 100. * correct / total, epoch) acc = 100. 
* correct / total train_loss = train_loss/(batch_idx + 1) if args.step_info == 'true': TRAIN_INFO['epoch_time'].append(float("{:.4f}".format(epoch_time))) # save diagonal blocks of exact Fisher inverse or its approximations if args.save_inv == 'true': all_modules = net.modules() count = 0 start, end = 0, 0 if optim_name == 'ngd': for m in all_modules: if m.__class__.__name__ == 'Linear': with torch.no_grad(): I = m.I G = m.G J = torch.einsum('ni,no->nio', I, G) J = J.reshape(J.size(0), -1) JTDJ = torch.matmul(J.t(), torch.matmul(m.NGD_inv, J)) / args.batch_size with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) count += 1 elif m.__class__.__name__ == 'Conv2d': with torch.no_grad(): AX = m.AX AX = AX.reshape(AX.size(0), -1) JTDJ = torch.matmul(AX.t(), torch.matmul(m.NGD_inv, AX)) / args.batch_size with open('ngd/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: np.save(f, ((torch.eye(JTDJ.size(0)).to(JTDJ.device) - JTDJ) / damping).cpu().numpy()) count += 1 elif optim_name == 'exact_ngd': for m in all_modules: if m.__class__.__name__ in ['Conv2d', 'Linear']: with open('exact/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: end = start + m.weight.data.reshape(1, -1).size(1) np.save(f, inv[start:end,start:end].cpu().numpy()) start = end + m.bias.data.size(0) count += 1 elif optim_name == 'kfac': for m in all_modules: if m.__class__.__name__ in ['Conv2d', 'Linear']: with open('kfac/' + str(epoch) + '_m_' + str(count) + '_inv.npy', 'wb') as f: G = optimizer.m_gg[m] A = optimizer.m_aa[m] H_g = torch.linalg.inv(G + math.sqrt(damping) * torch.eye(G.size(0)).to(G.device)) H_a = torch.linalg.inv(A + math.sqrt(damping) * torch.eye(A.size(0)).to(A.device)) end = m.weight.data.reshape(1, -1).size(1) kfac_inv = torch.kron(H_a, H_g)[:end,:end] np.save(f, kfac_inv.cpu().numpy()) count += 1 return acc, train_loss
model = Sequential(Flatten(), Linear(784, 10))
lossfunc = CrossEntropyLoss()

model = extend(model)
lossfunc = extend(lossfunc)

# %%
# First order extensions
# ----------------------

# %%
# Batch gradients

loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()

for name, param in model.named_parameters():
    print(name)
    print(".grad.shape:       ", param.grad.shape)
    print(".grad_batch.shape: ", param.grad_batch.shape)

# %%
# Variance

loss = lossfunc(model(X), y)
with backpack(Variance()):
    loss.backward()

for name, param in model.named_parameters():
tproblem = extend_with_access_unreduced_loss(tproblem)

# forward pass
batch_loss, _ = tproblem.get_batch_loss_and_accuracy()

# individual loss
savefield = "_unreduced_loss"
individual_loss = getattr(batch_loss, savefield)

print("Individual loss shape:   ", individual_loss.shape)
print("Mini-batch loss:         ", batch_loss)
print("Averaged individual loss:", individual_loss.mean())

# It is still possible to use BackPACK in the backward pass
with backpack(
    BatchGrad(),
    Variance(),
    SumGradSquared(),
    BatchL2Grad(),
    DiagGGNExact(),
    DiagGGNMC(),
    KFAC(),
    KFLR(),
    KFRA(),
    DiagHessian(),
):
    batch_loss.backward()

# print info
for name, param in tproblem.net.named_parameters():
    print(name)
# Computing clipped individual gradients
# -----------------------------------------------------------------
#
# Before writing the optimizer class, let's see how we can use ``BackPACK``
# on a single batch to compute the clipped gradients, without the overhead
# of the optimizer class.
#
# We take a single batch from the data loader, compute the loss,
# and use the ``with(backpack(...))`` syntax to activate two extensions;
# ``BatchGrad`` and ``BatchL2Grad``.

x, y = next(iter(mnist_dataloader))
x, y = x.to(DEVICE), y.to(DEVICE)

loss = loss_function(model(x), y)

with backpack(BatchL2Grad(), BatchGrad()):
    loss.backward()

# %%
# ``BatchGrad`` computes individual gradients and ``BatchL2Grad`` their norm (squared),
# which get stored in the ``grad_batch`` and ``batch_l2`` attributes of the parameters

for p in model.parameters():
    print("{:28} {:32} {}".format(
        str(p.grad.shape), str(p.grad_batch.shape), str(p.batch_l2.shape)))

# %%
# To compute the clipped gradients, we need to know the norms of the complete
# individual gradients, but at the moment they are split across parameters,
# so let's reduce over the parameters
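# %%
# The tutorial's reduction step is not included in this snippet. A hedged sketch
# of how the global per-sample norms could be assembled from the per-parameter
# ``batch_l2`` quantities and turned into clipping factors (assumes ``torch`` is
# imported; the threshold ``C`` is a hypothetical value):

C = 0.1  # hypothetical clipping threshold

l2_norms_squared = torch.stack([p.batch_l2 for p in model.parameters()])
l2_norms = torch.sqrt(l2_norms_squared.sum(0))        # one global norm per sample
scaling_factors = torch.clamp_max(C / l2_norms, 1.0)  # clipping factor per sample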
def run_exp(meta_seed, nhid, nlayers, n_train_seeds): torch.manual_seed(meta_seed) np.random.seed(meta_seed) gamma = 0.9 def init_weights(m): if isinstance(m, torch.nn.Linear): k = np.sqrt(6 / (np.sum(m.weight.shape))) m.weight.data.uniform_(-k, k) m.bias.data.fill_(0) if isinstance(m, torch.nn.Conv2d): u, v, w, h = m.weight.shape k = np.sqrt(6 / (w * h * u + w * h * v)) m.weight.data.uniform_(-k, k) m.bias.data.fill_(0) env = CifarWindowEnv(_train_x, _train_y) test_env = CifarWindowEnvBatch() env.step_reward = 0.05 test_env.step_reward = 0.05 ##nhid = 32 act = torch.nn.LeakyReLU() #act = torch.nn.Tanh() model = torch.nn.Sequential(*([ torch.nn.Conv2d(4, nhid, 5, stride=2), act, torch.nn.Conv2d(nhid, nhid * 2, 3), act, torch.nn.Conv2d(nhid * 2, nhid * 4, 3), act ] + sum([[torch.nn.Conv2d(nhid * 4, nhid * 4, 3, padding=1), act] for i in range(nlayers)], []) + [ torch.nn.Flatten(), torch.nn.Linear(nhid * 4 * 10 * 10, nhid * 4), act, torch.nn.Linear(nhid * 4, 14) ])) model.to(device) model.apply(init_weights) if 1: model = extend(model) opt = torch.optim.Adam(model.parameters(), 1e-5) #, weight_decay=1e-5) #opt = torch.optim.SGD(model.parameters(), 1e-4, momentum=0.9) def run_test(X, Y, dataacc=None): obs = test_env.reset(X, Y) accr = np.zeros(len(X)) for i in range(test_env.max_steps): o = model(obs) cls_act = Categorical(logits=o[:, :10]).sample() mov_act = Categorical(logits=o[:, 10:]).sample() actions = torch.stack([cls_act, mov_act]).data.cpu().numpy() obs, r, done, _ = test_env.step(actions) accr += r if dataacc is not None: dataacc.append(obs[np.random.randint(0, len(obs))]) dataacc.append(obs[np.random.randint(0, len(obs))]) if done.all(): break return test_env.correct_answers / len(X), test_env.acc_reward train_perf = [] test_perf = [] all_dots = [] all_dots_test = [] tds = [] qs = [] xent = torch.nn.CrossEntropyLoss() tau = 0.1 n_rp = 1000 rp_s = torch.zeros((n_rp, 4, 32, 32), device=device) rp_a = torch.zeros((n_rp, 2), device=device, dtype=torch.long) rp_g = torch.zeros((n_rp, ), device=device) rp_idx = 0 rp_fill = 0 obs = env.reset(np.random.randint(0, n_train_seeds)) ntest = 128 epsilon = 0.9 ep_reward = 0 ep_rewards = [] ep_start = 0 for i in tqdm(range(200001)): epsilon = 0.9 * (1 - min(i, 100000) / 100000) + 0.05 if not i % 1000: t0 = time.time() dataacc = [] with torch.no_grad(): train_perf.append( run_test(_train_x[:min(ntest, n_train_seeds)], _train_y[:min(ntest, n_train_seeds)])) test_perf.append( run_test(test_x[:ntest], test_y[:ntest], dataacc=dataacc)) print(train_perf[-2:], test_perf[-2:], np.mean(ep_rewards[-50:]), len(ep_rewards)) if 1: t1 = time.time() s = rp_s[:128] if 1: loss = sumloss(model(s).mean()) with backpack(BatchGrad()): loss.backward() all_grads = torch.cat([ i.grad_batch.reshape((s.shape[0], -1)) for i in model.parameters() ], 1) else: all_grads = [] for k in range(len(s)): Qsa = model(s[k][None, :]) grads = torch.autograd.grad(Qsa.max(), model.parameters()) fg = torch.cat([i.reshape((-1, )) for i in grads]) all_grads.append(fg) dots = [] for k in range(len(s)): for j in range(k + 1, len(s)): dots.append(all_grads[k].dot(all_grads[j]).item()) all_dots.append(np.float32(dots).mean()) opt.zero_grad() s = torch.stack(dataacc[:128]) loss = sumloss(model(s).mean()) with backpack(BatchGrad()): loss.backward() all_grads = torch.cat([ i.grad_batch.reshape((s.shape[0], -1)) for i in model.parameters() ], 1) dots = [] for k in range(len(s)): for j in range(k + 1, len(s)): dots.append(all_grads[k].dot(all_grads[j]).item()) 
all_dots_test.append(np.float32(dots).mean()) opt.zero_grad() if i and 0: print(i, (cls_pi * torch.log(cls_pi)).mean().item(), (mov_pi * torch.log(mov_pi)).mean().item()) print(cls_pi[0].data.cpu().numpy()) o = model(obs[None, :]) cls_act = Categorical(logits=o[0, :10]).sample().item() #cls_act = env.current_y mov_act = Categorical(logits=o[0, 10:]).sample().item() action = np.int32([cls_act, mov_act]) #if np.random.uniform(0,1) < 0.4: # action = env.current_y obsp, r, done, _ = env.step(action) rp_s[rp_idx] = obs rp_a[rp_idx] = torch.tensor(action) rp_idx = (rp_idx + 1) % rp_s.shape[0] rp_fill += 1 ep_reward += r obs = obsp if done: rp_g[ep_start:i] = ep_reward ep_rewards.append(ep_reward) ep_reward = 0 ep_start = i obs = env.reset(np.random.randint(0, n_train_seeds)) if rp_idx > 250: rp_at = rp_idx rp_idx = 0 rp_fill = 0 s = rp_s[:rp_at] a = rp_a[:rp_at] g = rp_g[:rp_at] o = model(s) cls_pi = F.softmax(o[:, :10], 1).clamp(min=1e-5) mov_pi = F.softmax(o[:, 10:], 1).clamp(min=1e-5) cls_prob = torch.log(cls_pi[torch.arange(len(a)), a[:, 0]]) mov_prob = torch.log(mov_pi[torch.arange(len(a)), a[:, 1]]) #import pdb; pdb.set_trace() loss = -(g * (cls_prob + mov_prob)).mean() loss += 1e-4 * ((cls_pi * torch.log(cls_pi)).sum(1).mean() + (mov_pi * torch.log(mov_pi)).sum(1).mean()) loss.backward() #for p in model.parameters(): # p.grad.data.clamp_(-1, 1) opt.step() opt.zero_grad() s = rp_s[:200] all_grads = [] for i in range(len(s)): Qsa = model(s[i][None, :]) grads = torch.autograd.grad(Qsa.max(), model.parameters()) fg = torch.cat([i.reshape((-1, )) for i in grads], 0) all_grads.append(fg) dots = [] cosd = [] for i in range(len(s)): for j in range(i + 1, len(s)): dots.append(all_grads[i].dot(all_grads[j]).item()) cosd.append(all_grads[i].dot(all_grads[j]).item() / (np.sqrt( (all_grads[i]**2).sum().item()) * np.sqrt( (all_grads[j]**2).sum().item()))) dots = np.float32(dots) print(np.mean(train_perf[-5:]), np.mean(test_perf[-5:])) #, all_dots[-5:]) return { 'dots': dots, 'cosd': cosd, 'all_dots': all_dots, 'all_dots_test': all_dots_test, 'train_perf': train_perf, 'test_perf': test_perf, }
def run_exp(meta_seed, nhid, nlayers, n_train_seeds): torch.manual_seed(meta_seed) np.random.seed(meta_seed) gamma = 0.9 train_x = _train_x[:n_train_seeds] train_y = _train_y[:n_train_seeds] ##nhid = 32 act = torch.nn.LeakyReLU() #act = torch.nn.Tanh() model = torch.nn.Sequential(*([ torch.nn.Conv2d(3, nhid, 5, stride=2), act, torch.nn.Conv2d(nhid, nhid * 2, 3), act, torch.nn.Conv2d(nhid * 2, nhid * 4, 3), act ] + sum([[torch.nn.Conv2d(nhid * 4, nhid * 4, 3, padding=1), act] for i in range(nlayers)], []) + [ torch.nn.Flatten(), torch.nn.Linear(nhid * 4 * 10 * 10, nhid), act, torch.nn.Linear(nhid, 10) ])) def init_weights(m): if isinstance(m, torch.nn.Linear): k = np.sqrt(6 / (np.sum(m.weight.shape))) m.weight.data.uniform_(-k, k) m.bias.data.fill_(0) model.to(device) model.apply(init_weights) if 0: model = extend(model) opt = torch.optim.Adam(model.parameters(), 1e-3) #, weight_decay=1e-5) def compute_acc_mb(X, Y, mbs=1024): N = len(X) i = 0 tot = 0 while i < N: x = X[i:i + mbs] y = Y[i:i + mbs] tot += np.sum( model(x).argmax(1).cpu().data.numpy() == y.cpu().data.numpy()) i += mbs return tot / N train_perf = [] test_perf = [] all_dots = [] xent = torch.nn.CrossEntropyLoss() for i in tqdm(range(1000)): if not i % 20: train_perf.append(compute_acc_mb(train_x, train_y)) test_perf.append(compute_acc_mb(test_x, test_y)) s = train_x[:96] if 0: loss = sumloss(model(s).max(1).values) with backpack(BatchGrad()): loss.backward() all_grads = torch.cat([ i.grad_batch.reshape((mbs, -1)) for i in model.parameters() ], 1) all_grads = [] for i in range(len(s)): Qsa = model(s[i][None, :]) grads = torch.autograd.grad(Qsa.max(), model.parameters()) fg = torch.cat([i.reshape((-1, )) for i in grads]) all_grads.append(fg) dots = [] cosd = [] cosd = [] for i in range(len(s)): for j in range(i + 1, len(s)): dots.append(all_grads[i].dot(all_grads[j]).item()) all_dots.append(np.float32(dots).mean()) opt.zero_grad() mbidx = np.random.randint(0, len(train_x), 32) x = train_x[mbidx] y = train_y[mbidx] pred = model(x) loss = xent(pred, y) loss.backward() opt.step() opt.zero_grad() s = train_x[:200] all_grads = [] for i in range(len(s)): Qsa = model(s[i][None, :]) grads = torch.autograd.grad(Qsa.max(), model.parameters()) fg = torch.cat([i.reshape((-1, )) for i in grads], 0) all_grads.append(fg) dots = [] cosd = [] for i in range(len(s)): for j in range(i + 1, len(s)): dots.append(all_grads[i].dot(all_grads[j]).item()) cosd.append(all_grads[i].dot(all_grads[j]).item() / (np.sqrt( (all_grads[i]**2).sum().item()) * np.sqrt( (all_grads[j]**2).sum().item()))) dots = np.float32(dots) print(train_perf[-5:], test_perf[-5:], all_dots[-5:]) return { 'dots': dots, 'cosd': cosd, 'all_dots': all_dots, 'train_perf': train_perf, 'test_perf': test_perf, }