def density(self, iter=100, n_v=1):
    """Estimate the eigenvalue density with stochastic Lanczos quadrature (SLQ).

    Each of the ``n_v`` runs works as follows:
    - draw a Rademacher probe vector and normalize it;
    - run ``iter`` Lanczos steps, fully re-orthogonalizing every new Lanczos
      vector against all previous ones (via ``orthnormal``);
    - eigendecompose the resulting tridiagonal matrix T.

    The eigenvalues of T are the quadrature nodes. The squared first
    components of its eigenvectors are the quadrature weights.

    Args:
        iter: number of Lanczos iterations per run (size of T); must be >= 1.
        n_v: number of independent SLQ runs.

    Returns:
        (eigen_list_full, weight_list_full): two lists with one entry per run.
        Each entry is a list of ``iter`` eigenvalues / weights.

    Raises:
        ValueError: if ``iter`` < 1 (T would be empty and have no weights).
    """
    if iter < 1:
        raise ValueError("iter must be >= 1, got {}".format(iter))

    device = self.device
    eigen_list_full = []
    weight_list_full = []

    for _ in range(n_v):
        v = [
            torch.randint_like(p, high=2, device=device)
            for p in self.params
        ]
        # map {0, 1} -> {-1, +1}: Rademacher random probe vector
        for v_i in v:
            v_i[v_i == 0] = -1
        v = normalization(v)

        # standard Lanczos initialization
        v_list = [v]
        alpha_list = []
        beta_list = []

        ############### Lanczos
        for i in range(iter):
            self.model.zero_grad()
            if i > 0:
                beta = torch.sqrt(group_product(w, w))
                beta_list.append(beta.cpu().item())
                if beta_list[-1] == 0.:
                    # Krylov subspace exhausted: continue from a fresh
                    # random direction.
                    w = [
                        torch.randn(p.size()).to(device)
                        for p in self.params
                    ]
                # full re-orthogonalization against all previous vectors
                v = orthnormal(w, v_list)
                v_list.append(v)

            if self.full_dataset:
                _, w_prime = self.dataloader_hv_product(v)
            else:
                w_prime = hessian_vector_product(
                    self.gradsH, self.params, v)
            alpha = group_product(w_prime, v)
            alpha_list.append(alpha.cpu().item())
            w = group_add(w_prime, v, alpha=-alpha)
            if i > 0:
                # three-term recurrence: also remove the previous vector
                w = group_add(w, v_list[-2], alpha=-beta)

        # Assemble the symmetric tridiagonal Lanczos matrix T.
        T = torch.zeros(iter, iter).to(device)
        for i, a in enumerate(alpha_list):
            T[i, i] = a
        for i, b in enumerate(beta_list):
            T[i + 1, i] = b
            T[i, i + 1] = b

        # T is real symmetric, so use the symmetric eigensolver. torch.eig
        # was deprecated in PyTorch 1.9 and later removed; keep it only as
        # a fallback for versions that predate torch.linalg.eigh (1.8).
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            eigen_list, eigvecs = torch.linalg.eigh(T)
        else:
            a_, eigvecs = torch.eig(T, eigenvectors=True)
            eigen_list = a_[:, 0]  # real parts
        # Quadrature weights are the squared first eigenvector components.
        weight_list = eigvecs[0, :]**2

        eigen_list_full.append(list(eigen_list.cpu().numpy()))
        weight_list_full.append(list(weight_list.cpu().numpy()))

    return eigen_list_full, weight_list_full
def density(self, iter=100, n_v=1, debug=False):
    """Estimate the eigenvalue density with stochastic Lanczos quadrature (SLQ).

    Each of the ``n_v`` runs works as follows:
    - draw a Rademacher probe vector and normalize it;
    - run ``iter`` Lanczos steps, fully re-orthogonalizing every new Lanczos
      vector against all previous ones (via ``orthnormal``);
    - eigendecompose the resulting tridiagonal matrix T.

    The eigenvalues of T are the quadrature nodes. The squared first
    components of its eigenvectors are the quadrature weights.

    When ``self.record_data`` is true, the total elapsed time of all runs is
    written to ``self.data_save_dir + "ESD_<timestamp>.txt"``.

    Args:
        iter: number of Lanczos iterations per run (size of T); must be >= 1.
        n_v: number of independent SLQ runs.
        debug: if true, print the Lanczos iteration index.

    Returns:
        (eigen_list_full, weight_list_full): two lists with one entry per run.
        Each entry is a list of ``iter`` eigenvalues / weights.

    Raises:
        ValueError: if ``iter`` < 1 (T would be empty and have no weights).
    """
    if iter < 1:
        raise ValueError("iter must be >= 1, got {}".format(iter))

    device = self.device
    eigen_list_full = []
    weight_list_full = []

    # Prepare to record data
    if self.record_data:
        now = datetime.datetime.now()
        timestamp = "_{:02d}{:02d}_{:02d}{:02d}{:02d}".format(
            now.day, now.month, now.hour, now.minute, now.second)
        save_file = self.data_save_dir + "ESD" + timestamp + ".txt"

    # monotonic clock: immune to wall-clock adjustments during long runs
    start_time = time.perf_counter()

    for _ in range(n_v):
        v = [
            torch.randint_like(p, high=2, device=device)
            for p in self.params
        ]
        # map {0, 1} -> {-1, +1}: Rademacher random probe vector
        for v_i in v:
            v_i[v_i == 0] = -1
        v = normalization(v)

        # standard Lanczos initialization
        v_list = [v]
        alpha_list = []
        beta_list = []

        ############### Lanczos
        for i in range(iter):
            if debug:
                print("Iteration {}".format(i))
            self.model.zero_grad()
            if i > 0:
                beta = torch.sqrt(group_product(w, w))
                beta_list.append(beta.cpu().item())
                if beta_list[-1] == 0.:
                    # Krylov subspace exhausted: continue from a fresh
                    # random direction.
                    w = [
                        torch.randn(p.size()).to(device)
                        for p in self.params
                    ]
                # full re-orthogonalization against all previous vectors
                v = orthnormal(w, v_list)
                v_list.append(v)

            if self.full_dataset:
                _, w_prime = self.dataloader_hv_product(v)
            else:
                w_prime = hessian_vector_product(
                    self.gradsH, self.params, v)
            alpha = group_product(w_prime, v)
            alpha_list.append(alpha.cpu().item())
            w = group_add(w_prime, v, alpha=-alpha)
            if i > 0:
                # three-term recurrence: also remove the previous vector
                w = group_add(w, v_list[-2], alpha=-beta)

        # Assemble the symmetric tridiagonal Lanczos matrix T.
        T = torch.zeros(iter, iter).to(device)
        for i, a in enumerate(alpha_list):
            T[i, i] = a
        for i, b in enumerate(beta_list):
            T[i + 1, i] = b
            T[i, i + 1] = b

        # T is real symmetric, so use the symmetric eigensolver. torch.eig
        # was deprecated in PyTorch 1.9 and later removed; keep it only as
        # a fallback for versions that predate torch.linalg.eigh (1.8).
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
            eigen_list, eigvecs = torch.linalg.eigh(T)
        else:
            a_, eigvecs = torch.eig(T, eigenvectors=True)
            eigen_list = a_[:, 0]  # real parts
        # Quadrature weights are the squared first eigenvector components.
        weight_list = eigvecs[0, :]**2

        eigen_list_full.append(list(eigen_list.cpu().numpy()))
        weight_list_full.append(list(weight_list.cpu().numpy()))

    # Write data if applicable
    stop_time = time.perf_counter()
    if self.record_data:
        with open(save_file, 'w') as f:
            f.write("Total Elapsed Time(s)\n")
            f.write("{}\n".format(stop_time - start_time))

    return eigen_list_full, weight_list_full