def compute_loss_surface_loader(model, loader, v1, v2,
                                loss=torch.nn.CrossEntropyLoss(),
                                n_batch=10, n_pts=50, range_=10.,
                                device=torch.device("cuda:0")):
    """Evaluate the loss on a 2-D grid of parameter perturbations.

    For every pair ``(a, b)`` drawn from ``n_pts`` evenly spaced values in
    ``[-range_, range_]`` the model's parameters are shifted by
    ``a * v1 + b * v2``, the loss is computed with ``compute_loader_loss``,
    and the model is restored to its starting state.

    Args:
        model: ``torch.nn.Module`` whose parameters are temporarily perturbed.
        loader: data loader forwarded to ``compute_loader_loss``.
        v1, v2: direction tensors. NOTE(review): they are transposed before
            ``utils.unflatten_like``, which suggests a column shape such as
            ``(n_params, 1)`` -- confirm against callers.
        loss: criterion forwarded to ``compute_loader_loss``.
        n_batch: batch count forwarded to ``compute_loader_loss``.
        n_pts: number of grid points along each axis.
        range_: half-width of the grid; a Python number or a one-element
            tensor.
        device: device forwarded to ``compute_loader_loss``.

    Returns:
        ``(X, Y, loss_surf)``: ``X, Y`` come from ``np.meshgrid`` over the grid
        coordinates, and ``loss_surf`` is an ``(n_pts, n_pts)`` tensor where
        ``loss_surf[ii, jj]`` is the loss at
        ``vec_len[ii] * v1 + vec_len[jj] * v2``. With meshgrid's default
        ``'xy'`` indexing, ``Y[ii, jj]`` is the ``v1`` coefficient and
        ``X[ii, jj]`` the ``v2`` coefficient.
    """
    import copy

    # Snapshot the state BY VALUE. ``state_dict()`` returns tensors that share
    # storage with the live parameters/buffers, so any in-place update during
    # evaluation (e.g. BatchNorm running statistics) would otherwise corrupt
    # the backup. deepcopy also preserves the dict's ``_metadata``, which
    # ``load_state_dict`` uses.
    start_pars = copy.deepcopy(model.state_dict())

    # float() accepts both plain numbers (like the default ``10.``) and
    # one-element tensors; ``range_.item()`` crashed on the float default.
    half_width = float(range_)
    vec_len = torch.linspace(-half_width, half_width, n_pts)
    ## init loss surface ##
    loss_surf = torch.zeros(n_pts, n_pts)

    with torch.no_grad():
        ## loop and get loss at each point ##
        for ii in range(n_pts):
            for jj in range(n_pts):
                try:
                    perturb = v1.mul(vec_len[ii]) + v2.mul(vec_len[jj])
                    perturb = utils.unflatten_like(perturb.t(),
                                                   model.parameters())
                    for i, par in enumerate(model.parameters()):
                        par.data = par.data + perturb[i].to(par.device)

                    loss_surf[ii, jj] = compute_loader_loss(model, loader,
                                                            loss, n_batch,
                                                            device=device)
                finally:
                    # Always return to the starting point, even if applying
                    # the perturbation or evaluating the loss raised.
                    model.load_state_dict(start_pars)

    grid = vec_len.numpy()
    X, Y = np.meshgrid(grid, grid)
    return X, Y, loss_surf
 def export_base_parameters(self, base_model, index):
     """Copy one stored parameter vector into ``base_model`` in place.

     Column ``index`` of ``self.full_parameters`` is taken as a flat
     parameter vector. It is split by ``utils.unflatten_like`` into pieces
     shaped like ``base_model``'s parameters, and each piece is copied into
     the matching parameter.

     NOTE(review): assumes ``self.full_parameters`` is laid out as
     ``(n_params, n_points)`` with entries in ``base_model.parameters()``
     order -- confirm against where it is built.

     Args:
         base_model: ``torch.nn.Module`` that receives the parameters.
         index: column of ``self.full_parameters`` to export.
     """
     # Materialize once so the same ordered parameter list supplies the
     # shapes for unflattening and the copy targets.
     base_parameters = list(base_model.parameters())
     new_pars = self.full_parameters[:, index].unsqueeze(0)
     # unflatten_like takes an iterable of tensors to copy shapes from; an
     # nn.Module itself does not iterate over its parameters, so pass the list.
     new_pars = utils.unflatten_like(new_pars, base_parameters)

     for parameter, base_parameter in zip(new_pars, base_parameters):
         base_parameter.data.copy_(parameter.data)
    def add_vert(self, to_simplexes=None):
        """Grow the simplex network by one vertex.

        A new network is built with ``self.architecture`` for
        ``self.n_vert + 1`` vertices, where every existing vertex is marked
        fixed and only the new one is free. The existing vertices' parameters
        are copied into it, and the new vertex is set to the element-wise
        mean of the existing vertices. The new vertex index is then appended
        to each simplex listed in ``to_simplexes``.

        NOTE(review): the strided slicing assumes each network registers its
        parameters interleaved by vertex (parameter ``k`` of vertex ``v`` at
        position ``k * n_vert + v``) -- confirm against SimplexModule.

        Args:
            to_simplexes: indices into ``self.simplicial_complex`` that should
                contain the new vertex. ``None`` (the default) means ``[0]``.
        """
        # Avoid a shared mutable default argument.
        if to_simplexes is None:
            to_simplexes = [0]

        n_old = self.n_vert
        n_new = n_old + 1

        # Build the new network before touching ``self`` so a failure here
        # leaves the object in its previous, consistent state.
        fix_points = [True] * n_old + [False]
        new_model = self.architecture(self.n_output,
                                      fix_points=fix_points,
                                      **self.architecture_kwargs)

        old_all = list(self.net.parameters())
        new_all = list(new_model.parameters())

        ## assign old pars to new model ##
        for index in range(n_old):
            for old_par, new_par in zip(old_all[index::n_old],
                                        new_all[index::n_new]):
                new_par.data.copy_(old_par.data)

        ## assign mean of old pars to new vertex ##
        new_parameters = new_all[n_old::n_new]
        n_par = sum(p.numel() for p in new_parameters)
        par_vecs = torch.zeros(n_old, n_par).to(new_parameters[0].device)
        for ii in range(n_old):
            par_vecs[ii, :] = utils.flatten(old_all[ii::n_old])

        center_pars = torch.mean(par_vecs, 0).unsqueeze(0)
        center_pars = utils.unflatten_like(center_pars, new_parameters)
        for cntr, par in zip(center_pars, new_parameters):
            par.data = cntr.to(par.device)

        ## update self values ##
        self.fix_points = fix_points
        self.n_vert = n_new
        self.net = new_model
        self.simplex_modules = [module for module in self.net.modules()
                                if isinstance(module, SimplexModule)]

        for cc in to_simplexes:
            self.simplicial_complex[cc].append(self.n_vert - 1)
 def assign_pars(self, pars):
     """Overwrite ``self.net``'s parameters in place from ``pars``.

     ``pars`` is split by ``utils.unflatten_like`` into chunks shaped like
     the network's parameters, and each chunk is copied into the parameter
     at the same position.
     """
     live_params = list(self.net.parameters())
     chunks = utils.unflatten_like(pars, live_params)
     for target, chunk in zip(live_params, chunks):
         target.data.copy_(chunk.data)