def _local_search(self, prediction, sat_problem, batch_replication):
    """Implements the Walk-SAT algorithm for post-processing.

    Starting from the network's soft variable prediction, runs up to
    ``self._local_search_iterations`` rounds of epsilon-greedy variable
    flipping on the still-unsatisfied problem instances.

    :param prediction: pair; prediction[0] holds per-variable probabilities
        in [0, 1] (thresholded at 0.5 here), prediction[1] is passed
        through untouched as the function-side prediction.
    :param sat_problem: the batched SAT problem (provides the sparse mask
        tuples, active-variable/function indicators and counts used below).
    :param batch_replication: number of replicas per instance; when > 1,
        an instance counts as solved once ANY of its replicas is satisfied.
    :return: (assignment mapped back to [0, 1], prediction[1]).
    """
    # Hard-threshold the soft prediction, then map {0,1} -> {-1,+1} and
    # zero out inactive variables.
    assignment = (prediction[0] > 0.5).float()
    assignment = sat_problem._active_variables * (2 * assignment - 1.0)
    # Edge is "live" only if both its variable and its function are active.
    sat_problem._edge_mask = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) * \
        torch.mm(sat_problem._graph_mask_tuple[3], sat_problem._active_functions)
    for _ in range(self._local_search_iterations):
        unsat_examples, unsat_functions = self._compute_energy(assignment, sat_problem)
        unsat_examples = (unsat_examples > 0).float()  # 1 where the example still has unsatisfied clauses
        if batch_replication > 1:
            # An original instance is unsat only if ALL of its replicas are unsat.
            compact_unsat_examples = 1 - (torch.mm(
                sat_problem._replication_mask_tuple[1], 1 - unsat_examples) > 0).float()
            if compact_unsat_examples.sum() == 0:
                break
        elif unsat_examples.sum() == 0:
            break
        # Greedy candidate: per example, the variable whose flip lowers the
        # energy the most (argmax of -delta within each batch segment).
        delta_energy = self._compute_energy_diff(assignment, sat_problem)
        max_delta_ind = util.sparse_argmax(
            -delta_energy.squeeze(1),
            sat_problem._batch_mask_tuple[0], device=self._device)
        # Random candidate: a uniformly random active variable that appears
        # in at least one unsatisfied clause (random scores + argmax).
        unsat_variables = torch.mm(
            sat_problem._vf_mask_tuple[0], unsat_functions) * sat_problem._active_variables
        unsat_variables = (unsat_variables > 0).float() * torch.rand(
            [sat_problem._variable_num, 1], device=self._device)
        random_ind = util.sparse_argmax(unsat_variables.squeeze(1),
                                        sat_problem._batch_mask_tuple[0],
                                        device=self._device)
        # Epsilon-greedy choice per example: greedy with prob 1 - epsilon,
        # random otherwise.
        coin = (torch.rand(sat_problem._batch_size, device=self._device) > self._epsilon).long()
        max_ind = coin * max_delta_ind + (1 - coin) * random_ind
        # Only flip in examples that are still unsatisfied.
        max_ind = max_ind[unsat_examples[:, 0] > 0]
        # Flipping the selected variables
        assignment[max_ind, 0] = -assignment[max_ind, 0]
    # Map {-1,+1} (and inactive zeros) back to probability-like [0, 1].
    return (assignment + 1) / 2.0, prediction[1]
def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None):
    """One decimation step: when message passing has (nearly) converged for
    an instance, fix its highest-|score| variable via
    ``sat_problem.set_variables``.

    Tracks convergence by comparing the current function-to-variable
    messages against the previous call's (``self._previous_function_state``)
    and counts steps per instance in ``self._counters``; an instance is
    also force-decimated once its counter reaches ``self._t_max``.
    Returns ``message_state`` unchanged.
    """
    if self._counters is None:
        # Lazily allocate one step counter per batch entry.
        self._counters = torch.zeros(sat_problem._batch_size, 1, device=self._device)
    if active_mask is not None:
        # Deactivate instances whose strongest survey signal is ~zero
        # (nothing informative left to decimate on).
        # NOTE(review): message_state[1][:, 0] looks like the per-edge
        # function-to-variable survey channel — confirm against the propagator.
        survey = message_state[1][:, 0].unsqueeze(1)
        survey = util.sparse_smooth_max(survey, sat_problem._graph_mask_tuple[0], self._device)
        survey = survey * sat_problem._active_variables
        survey = util.sparse_max(survey.squeeze(1), sat_problem._batch_mask_tuple[0],
                                 self._device).unsqueeze(1)
        active_mask[survey <= 1e-10] = 0
    if self._previous_function_state is not None and sat_problem._active_variables.sum() > 0:
        # Per-edge change in the function messages since the previous step.
        function_diff = (self._previous_function_state - message_state[1][:, 0]).abs().unsqueeze(1)
        if sat_problem._edge_mask is not None:
            function_diff = function_diff * sat_problem._edge_mask
        # Reduce edge-level diffs to one convergence measure per instance:
        # max over each variable's edges, then max over the instance's variables.
        sum_diff = util.sparse_smooth_max(function_diff, sat_problem._graph_mask_tuple[0], self._device)
        sum_diff = sum_diff * sat_problem._active_variables
        sum_diff = util.sparse_max(sum_diff.squeeze(1), sat_problem._batch_mask_tuple[0],
                                   self._device).unsqueeze(1)
        # Reset the counter for instances that just converged...
        self._counters[sum_diff[:, 0] < self._tolerance, 0] = 0
        # ...and build the per-instance "decimate now" flag: converged, or timed out.
        sum_diff = (sum_diff < self._tolerance).float()
        sum_diff[self._counters[:, 0] >= self._t_max, 0] = 1
        self._counters[self._counters[:, 0] >= self._t_max, 0] = 0
        # Broadcast the instance flag down to its variables.
        sum_diff = torch.mm(sat_problem._batch_mask_tuple[0], sum_diff)
        if sum_diff.sum() > 0:
            score, _ = self._scorer(message_state, sat_problem)
            # Find the variable index with max score for each instance in the batch
            coeff = score.abs() * sat_problem._active_variables * sum_diff
            if coeff.sum() > 0:
                max_ind = util.sparse_argmax(coeff.squeeze(1), sat_problem._batch_mask_tuple[0],
                                             self._device)
                # norm > 0 only for instances that actually had a candidate;
                # drop argmax results from instances with an all-zero segment.
                norm = torch.mm(sat_problem._batch_mask_tuple[1], coeff)
                if active_mask is not None:
                    max_ind = max_ind[(active_mask * (norm != 0)).squeeze(1)]
                else:
                    max_ind = max_ind[norm.squeeze(1) != 0]
                if max_ind.size()[0] > 0:
                    # Pin each selected variable to the sign of its score.
                    assignment = torch.zeros(sat_problem._variable_num, 1, device=self._device)
                    assignment[max_ind, 0] = score.sign()[max_ind, 0]
                    sat_problem.set_variables(assignment)
    # Advance the per-instance step counters and remember the current
    # function messages for the next convergence check.
    self._counters = self._counters + 1
    self._previous_function_state = message_state[1][:, 0]
    return message_state
def _deduplicate(self, prediction, propagator_state, decimator_state, sat_problem):
    """Collapse a replicated batch back to one entry per problem instance.

    For every original instance, the replica with the lowest energy under
    the current prediction is selected; its variable/function predictions
    and its propagator/decimator state tensors are kept (all other
    replicas are masked to zero and summed away).

    :return: ((variable_prediction, function_prediction),
        new_propagator_state, new_decimator_state), or (None, None, None)
        when no replication is in effect.
    """
    replication = sat_problem._batch_replication
    if replication <= 1 or sat_problem._replication_mask_tuple is None:
        return None, None, None

    # Energy of the hard-signed assignment, per replicated batch entry.
    signed_assignment = 2 * prediction[0] - 1.0
    energy, _ = self._compute_energy(signed_assignment, sat_problem)

    # One-hot flag over batch entries marking each instance's best
    # (minimum-energy) replica.
    winner = util.sparse_argmax(-energy.squeeze(1),
                                sat_problem._replication_mask_tuple[0],
                                device=self._device)
    batch_flag = torch.zeros(sat_problem._batch_size, 1, device=self._device)
    batch_flag[winner, 0] = 1

    # Scatter the winner flag onto variables, keep only winning replicas'
    # values, and fold the replica axis away by summation.
    per_variable_flag = torch.mm(sat_problem._batch_mask_tuple[0], batch_flag)
    variable_prediction = (per_variable_flag * prediction[0]).view(
        replication, -1).sum(dim=0).unsqueeze(1)

    # Same collapse for the per-edge propagator/decimator state tensors.
    per_edge_flag = torch.mm(sat_problem._graph_mask_tuple[1], per_variable_flag)
    edges_per_replica = int(sat_problem._edge_num / replication)

    def collapse(state):
        # Zero out losing replicas' slices, then sum replicas together.
        return tuple(
            (per_edge_flag * tensor).view(replication, edges_per_replica, -1).sum(dim=0)
            for tensor in state)

    new_propagator_state = collapse(propagator_state)
    new_decimator_state = collapse(decimator_state)

    function_prediction = None
    if prediction[1] is not None:
        per_function_flag = torch.mm(sat_problem._batch_mask_tuple[2], batch_flag)
        function_prediction = (per_function_flag * prediction[1]).view(
            replication, -1).sum(dim=0).unsqueeze(1)

    return (variable_prediction, function_prediction), new_propagator_state, new_decimator_state