def compute_loss_surface_loader(model, loader, v1, v2,
                                loss=torch.nn.CrossEntropyLoss(),
                                n_batch=10, n_pts=50, range_=10.,
                                device=torch.device("cuda:0")):
    """Evaluate the loader loss on a 2-D grid of parameter perturbations.

    For every pair of coefficients ``(a, b)`` taken from an evenly spaced grid
    on ``[-range_, range_]``, the model parameters are shifted by
    ``a * v1 + b * v2``, the loss is computed with ``compute_loader_loss`` and
    the starting parameters are loaded back.

    Args:
        model: network whose parameters are perturbed (restored after every
            grid point).
        loader: data loader forwarded to ``compute_loader_loss``.
        v1, v2: perturbation directions. Their combination is transposed with
            ``.t()`` before being unflattened, so they are presumably column
            vectors of shape ``(n_params, 1)`` -- confirm against
            ``utils.unflatten_like``.
        loss: loss callable forwarded to ``compute_loader_loss``.
        n_batch: forwarded to ``compute_loader_loss``.
        n_pts: number of grid points along each direction.
        range_: half-width of the grid; either a plain number or an object
            exposing ``.item()`` (for example a one-element tensor).
        device: forwarded to ``compute_loader_loss``.

    Returns:
        ``(X, Y, loss_surf)``. ``X`` and ``Y`` come from ``np.meshgrid`` with
        its default ``'xy'`` indexing and ``loss_surf`` is an
        ``(n_pts, n_pts)`` tensor. ``loss_surf[i, j]`` uses coefficient
        ``vec_len[i]`` for ``v1`` and ``vec_len[j]`` for ``v2``, i.e.
        ``Y[i, j]`` is the ``v1`` coefficient and ``X[i, j]`` the ``v2`` one.
    """
    start_pars = model.state_dict()

    # The default ``range_`` is a Python float, which has no ``.item()``;
    # only unwrap objects that actually provide it.
    half_width = range_.item() if hasattr(range_, "item") else range_
    vec_len = torch.linspace(-half_width, half_width, n_pts)

    # The Parameter objects stay the same across ``load_state_dict`` calls,
    # so the list can be built once and reused for every grid point.
    params = list(model.parameters())

    ## init loss surface and the vector multipliers ##
    loss_surf = torch.zeros(n_pts, n_pts)
    with torch.no_grad():
        ## loop and get loss at each point ##
        for ii in range(n_pts):
            for jj in range(n_pts):
                perturb = v1.mul(vec_len[ii]) + v2.mul(vec_len[jj])
                perturb = utils.unflatten_like(perturb.t(), params)
                for par, delta in zip(params, perturb):
                    # Rebind ``.data`` to a fresh tensor rather than adding in
                    # place, so the tensors held by ``start_pars`` are not
                    # overwritten by the perturbation.
                    par.data = par.data + delta.to(par.device)

                loss_surf[ii, jj] = compute_loader_loss(model, loader, loss,
                                                        n_batch, device=device)

                # Undo the perturbation before moving to the next grid point.
                model.load_state_dict(start_pars)

    X, Y = np.meshgrid(vec_len, vec_len)
    return X, Y, loss_surf
def export_base_parameters(self, base_model, index):
    """Copy column ``index`` of ``self.full_parameters`` into ``base_model``.

    The selected column is turned into a single-row tensor, split into
    per-parameter tensors with ``utils.unflatten_like`` and copied in place
    into the parameters of ``base_model``.

    Args:
        base_model: module whose parameters are overwritten.
        index: column of ``self.full_parameters`` to export (presumably one
            flattened parameter vector per column -- confirm against where
            ``full_parameters`` is built).
    """
    # Materialise once: the list is both the shape template for unflattening
    # and the set of copy targets.
    base_parameters = list(base_model.parameters())

    new_pars = self.full_parameters[:, index].unsqueeze(0)
    # ``unflatten_like`` is given the parameter tensors as its template, not
    # the module object itself.
    new_pars = utils.unflatten_like(new_pars, base_parameters)

    for parameter, base_parameter in zip(new_pars, base_parameters):
        base_parameter.data.copy_(parameter.data)
def add_vert(self, to_simplexes=None):
    """Add one vertex to the model and register it with the given simplexes.

    A new network with ``n_vert + 1`` vertices is built through
    ``self.architecture``. Existing vertices are copied over and marked fixed
    in ``self.fix_points``; the new vertex is marked not fixed and is
    initialised to the element-wise mean of the existing vertices' parameters.

    The strided slicing below relies on ``parameters()`` being interleaved by
    vertex, i.e. the parameter at position ``k`` belongs to vertex
    ``k % n_vert``.

    Args:
        to_simplexes: indices into ``self.simplicial_complex`` whose vertex
            lists receive the new vertex. Defaults to ``[0]``.
    """
    if to_simplexes is None:
        to_simplexes = [0]

    old_n_vert = self.n_vert
    new_n_vert = old_n_vert + 1

    self.fix_points = [True] * old_n_vert + [False]
    new_model = self.architecture(self.n_output, fix_points=self.fix_points,
                                  **self.architecture_kwargs)

    # Build both parameter lists once instead of once per loop iteration.
    old_all = list(self.net.parameters())
    new_all = list(new_model.parameters())

    ## assign old pars to new model ##
    for index in range(old_n_vert):
        old_parameters = old_all[index::old_n_vert]
        new_parameters = new_all[index::new_n_vert]
        for old_par, new_par in zip(old_parameters, new_parameters):
            new_par.data.copy_(old_par.data)

    # Parameters of the freshly added vertex (last slot of every group).
    new_parameters = new_all[old_n_vert::new_n_vert]
    n_par = sum(p.numel() for p in new_parameters)

    ## assign mean of old pars to new vertex ##
    par_vecs = torch.zeros(old_n_vert, n_par).to(new_parameters[0].device)
    for ii in range(old_n_vert):
        par_vecs[ii, :] = utils.flatten(old_all[ii::old_n_vert])

    center_pars = torch.mean(par_vecs, 0).unsqueeze(0)
    center_pars = utils.unflatten_like(center_pars, new_parameters)
    for cntr, par in zip(center_pars, new_parameters):
        par.data = cntr.to(par.device)

    ## update self values ##
    self.n_vert = new_n_vert
    self.net = new_model
    self.simplex_modules = [
        module for module in self.net.modules()
        if isinstance(module, SimplexModule)
    ]
    for cc in to_simplexes:
        self.simplicial_complex[cc].append(self.n_vert - 1)
def assign_pars(self, pars):
    """Overwrite the parameters of ``self.net`` from a flattened tensor.

    ``pars`` is split into per-parameter tensors by ``utils.unflatten_like``
    using the network's parameters as the template, and each piece is copied
    in place into the matching parameter.
    """
    reshaped = utils.unflatten_like(pars, self.net.parameters())
    for target, source in zip(self.net.parameters(), reshaped):
        target.data.copy_(source.data)