def objective(self, batch_idx=None):
    """Evaluate the objective function

    Parameters
    ----------
    batch_idx : `list` [`int`], optional
        Indices of the batch. The default is `None` to evaluate over the
        full dataset. The evaluation is done in batches of size
        ``batch_size`` to avoid loading the full dataset into memory.

    Returns
    -------
    obj_value : `float`
        Objective value.

    """
    if batch_idx is not None:
        obj_value = self.obj_function(batch_idx).item()
    else:
        obj_value = 0
        for batch_start, batch_end in _batches(self.data_size,
                                               self.batch_size):
            # Batch-size-weighted running mean of the objective
            obj_value += self.obj_function(
                range(batch_start, batch_end)).item() * (
                    batch_end - batch_start) / self.data_size

    return obj_value
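
# `_batches` is used throughout this module but not defined in this
# section. The sketch below is a minimal stand-in inferred from how its
# output is consumed above: it is assumed to yield (batch_start, batch_end)
# pairs that tile range(data_size) in chunks of at most `batch_size`
# samples. The actual helper in this codebase may differ.
def _batches(data_size, batch_size):
    """Yield (batch_start, batch_end) pairs covering ``range(data_size)``."""
    if batch_size is None or batch_size > data_size:
        batch_size = data_size
    for batch_start in range(0, data_size, batch_size):
        yield batch_start, min(batch_start + batch_size, data_size)
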
def _primal(self, problem):
    primal_value_est = 0
    primal_grad_norm_est = 0

    # Shuffle dataset
    if self.settings['shuffle']:
        idx_epoch = np.random.permutation(np.arange(problem.data_size))
    else:
        idx_epoch = range(0, problem.data_size)

    for batch_start, batch_end in _batches(problem.data_size,
                                           self.settings['batch_size']):
        batch_idx = idx_epoch[batch_start:batch_end]

        # Gradient step (the Lagrangian evaluation is expected to
        # backpropagate internally, since its value is discarded here)
        self.primal_solver.zero_grad()
        _, obj_value, _, _ = problem.lagrangian(batch_idx)
        self.primal_solver.step()

        # Update running estimates of the objective value and of the
        # squared norm of the primal gradient
        with torch.no_grad():
            primal_value_est += obj_value * (
                batch_end - batch_start) / problem.data_size
            primal_grad_norm_est += np.sum([
                p.grad.norm().item()**2
                for p in problem.model.parameters
            ]) * (batch_end - batch_start) / problem.data_size

    return primal_value_est, primal_grad_norm_est
def lagrangian(self, batch_idx=None):
    """Evaluate the Lagrangian (and its gradient)

    Parameters
    ----------
    batch_idx : `list` [`int`], optional
        Indices of the batch. The default is `None` to evaluate over the
        full dataset. The evaluation is done in batches of size
        ``batch_size`` to avoid loading the full dataset into memory.

    Returns
    -------
    L : `float`
        Lagrangian value.
    obj_value : `float`
        Objective value.
    constraint_slacks : `list` [`torch.tensor`, (1, )]
        Slacks of the average constraints.
    pointwise_slacks : `list` [`torch.tensor`, (``len(batch_idx)``, )]
        Slacks of the pointwise constraints.

    """
    if batch_idx is not None:
        L, obj_value, constraint_slacks, pointwise_slacks = \
            self._lagrangian(batch_idx)
    else:
        # Initialization
        L = 0
        obj_value = 0
        constraint_slacks = [0] * len(self.constraints)
        pointwise_slacks = [torch.zeros([0])] * len(self.pointwise)

        # Compute over the whole dataset in batches
        for batch_start, batch_end in _batches(self.data_size,
                                               self.batch_size):
            L_batch, obj_value_batch, constraint_slacks_batch, \
                pointwise_slacks_batch = self._lagrangian(
                    np.arange(batch_start, batch_end))

            # Batch-size-weighted running means of averaged quantities
            L += L_batch * (batch_end - batch_start) / self.data_size
            obj_value += obj_value_batch * (
                batch_end - batch_start) / self.data_size
            for ii, slack in enumerate(constraint_slacks_batch):
                constraint_slacks[ii] += slack * (
                    batch_end - batch_start) / self.data_size

            # Pointwise slacks are concatenated, not averaged
            for ii, slack in enumerate(pointwise_slacks_batch):
                pointwise_slacks[ii] = torch.cat(
                    (pointwise_slacks[ii], slack))

    return L, obj_value, constraint_slacks, pointwise_slacks
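
# A minimal, self-contained illustration of the accumulation pattern used
# in `lagrangian` above: averaged quantities are combined as
# batch-size-weighted partial means (so the weights sum to one even when
# the last batch is smaller), while pointwise values are concatenated to
# recover one entry per sample. The data here is synthetic, for
# demonstration only.
def _demo_accumulation():
    import torch

    data_size, batch_size = 10, 4
    avg_est, pointwise = 0.0, torch.zeros([0])
    for start in range(0, data_size, batch_size):
        end = min(start + batch_size, data_size)
        batch_vals = torch.arange(start, end, dtype=torch.float)
        avg_est += batch_vals.mean().item() * (end - start) / data_size
        pointwise = torch.cat((pointwise, batch_vals))
    # The weighted partial means recover the full-dataset mean exactly
    assert abs(avg_est - pointwise.mean().item()) < 1e-6
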
def slacks(self, batch_idx=None):
    """Evaluate constraint slacks

    Parameters
    ----------
    batch_idx : `list` [`int`], optional
        Indices of the batch. The default is `None` to evaluate over the
        full dataset. The evaluation is done in batches of size
        ``batch_size`` to avoid loading the full dataset into memory.

    Returns
    -------
    constraint_slacks : `list` [`float`]
        Constraint violation of the average constraints.
    pointwise_slacks : `list` [`torch.tensor`, (``len(batch_idx)``, )]
        Constraint violation of the pointwise constraints.

    """
    if batch_idx is not None:
        constraint_slacks = self._constraint_slacks(batch_idx)
        pointwise_slacks = self._pointwise_slacks(batch_idx)
    else:
        constraint_slacks = [0] * len(self.constraints)
        pointwise_slacks = [torch.zeros([0])] * len(self.pointwise)

        for batch_start, batch_end in _batches(self.data_size,
                                               self.batch_size):
            # Average constraints: batch-size-weighted running mean
            for ii, s in enumerate(
                    self._constraint_slacks(range(batch_start, batch_end))):
                constraint_slacks[ii] += s * (
                    batch_end - batch_start) / self.data_size

            # Pointwise constraints: concatenate per-sample slacks
            for ii, s in enumerate(
                    self._pointwise_slacks(range(batch_start, batch_end))):
                pointwise_slacks[ii] = torch.cat((pointwise_slacks[ii], s))

    return constraint_slacks, pointwise_slacks
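
# Hypothetical usage of `slacks` (the `problem` instance and the `tol`
# tolerance are placeholders): since positive slacks measure constraint
# violation here, approximate feasibility can be checked by requiring all
# slacks to be below a small tolerance.
def _check_feasibility(problem, tol=1e-3):
    constraint_slacks, pointwise_slacks = problem.slacks()  # full dataset
    avg_ok = all(float(s) <= tol for s in constraint_slacks)
    pointwise_ok = all(bool((s <= tol).all()) for s in pointwise_slacks)
    return avg_ok and pointwise_ok
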
def primal_dual_update(self, problem):
    # Initialize estimates
    primal_value_est = 0
    primal_grad_norm_est = 0
    if self.state_dict['HAS_CONSTRAINTS']:
        constraint_slacks_est = [
            torch.tensor(0, dtype=torch.float, requires_grad=False,
                         device=self.settings['device'])
            for _ in problem.rhs
        ]
        pointwise_slacks_est = [
            torch.zeros_like(rhs, dtype=torch.float, requires_grad=False,
                             device=self.settings['device'])
            for rhs in problem.pointwise_rhs
        ]
        dual_grad_norm_est = 0
    else:
        constraint_slacks_est, pointwise_slacks_est, dual_grad_norm_est = \
            None, None, None

    # Shuffle dataset
    if self.settings['shuffle']:
        idx_epoch = np.random.permutation(np.arange(problem.data_size))
    else:
        idx_epoch = range(0, problem.data_size)

    ### START OF EPOCH ###
    for batch_start, batch_end in _batches(problem.data_size,
                                           self.settings['batch_size']):
        batch_idx = idx_epoch[batch_start:batch_end]

        ### PRIMAL UPDATE ###
        # Gradient step
        self.primal_solver.zero_grad()
        _, obj_value, constraint_slacks, pointwise_slacks = \
            problem.lagrangian(batch_idx)
        self.primal_solver.step()

        # Update running estimates of primal quantities
        with torch.no_grad():
            primal_value_est += obj_value * (
                batch_end - batch_start) / problem.data_size
            primal_grad_norm_est += np.sum([
                p.grad.norm().item()**2
                for p in problem.model.parameters
            ]) * (batch_end - batch_start) / problem.data_size

        ### DUAL UPDATE ###
        if self.state_dict['HAS_CONSTRAINTS']:
            # Set gradients of the average-constraint multipliers
            # (negated slacks, so a minimizing solver performs ascent)
            for ii, slack in enumerate(constraint_slacks):
                problem.lambdas[ii].grad = -slack
                constraint_slacks_est[ii] += slack * (
                    batch_end - batch_start) / problem.data_size

                # Only multipliers with nonzero projected gradient
                # contribute to the dual gradient norm estimate
                if problem.lambdas[ii] > 0 or (problem.lambdas[ii] == 0
                                               and slack > 0):
                    dual_grad_norm_est += slack**2 * (
                        batch_end - batch_start) / problem.data_size

            # Set gradients of the pointwise multipliers, expanding the
            # batch slacks to the full dataset size
            for ii, slack in enumerate(pointwise_slacks):
                expanded_slack = torch.zeros_like(problem.mus[ii])
                expanded_slack[batch_idx] = slack
                problem.mus[ii].grad = -expanded_slack
                pointwise_slacks_est[ii][batch_idx] = slack

                inactive = torch.logical_or(
                    problem.mus[ii][batch_idx] > 0,
                    torch.logical_and(problem.mus[ii][batch_idx] == 0,
                                      slack > 0))
                dual_grad_norm_est += torch.norm(slack[inactive]).item()**2

            # Dual gradient step
            self.dual_solver.step()

            # Project multipliers onto the non-negative orthant
            for ii, _ in enumerate(problem.lambdas):
                problem.lambdas[ii][problem.lambdas[ii] < 0] = 0
            for ii, _ in enumerate(problem.mus):
                problem.mus[ii][problem.mus[ii] < 0] = 0

    return primal_value_est, primal_grad_norm_est, constraint_slacks_est, \
        pointwise_slacks_est, dual_grad_norm_est
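
# Hypothetical outer training loop built on `primal_dual_update` (the
# `solver`, `problem`, and `num_epochs` names are placeholders; the real
# driver in this codebase may differ). Each epoch runs one pass of
# alternating primal descent / dual ascent and logs the returned running
# estimates.
def _train(solver, problem, num_epochs):
    for epoch in range(num_epochs):
        (primal_value, primal_grad_norm, constraint_slacks,
         pointwise_slacks, dual_grad_norm) = \
            solver.primal_dual_update(problem)
        print(f'[epoch {epoch}] obj ~ {primal_value:.4f}, '
              f'primal grad sq. norm ~ {primal_grad_norm:.3e}')
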