Example #1
    def _local_search(self, prediction, sat_problem, batch_replication):
        "Implements the Walk-SAT algorithm for post-processing."

        # Threshold the soft prediction and map it to a +/-1 assignment,
        # zeroing out variables that are no longer active.
        assignment = (prediction[0] > 0.5).float()
        assignment = sat_problem._active_variables * (2 * assignment - 1.0)

        # Mask out edges whose variable or clause has been deactivated.
        sat_problem._edge_mask = torch.mm(sat_problem._graph_mask_tuple[1], sat_problem._active_variables) * \
            torch.mm(sat_problem._graph_mask_tuple[3], sat_problem._active_functions)

        for _ in range(self._local_search_iterations):
            # Energy counts the unsatisfied clauses of each instance; stop once
            # every instance is satisfied (under batch replication, once at
            # least one replica of every instance is satisfied).
            unsat_examples, unsat_functions = self._compute_energy(
                assignment, sat_problem)
            unsat_examples = (unsat_examples > 0).float()

            if batch_replication > 1:
                compact_unsat_examples = 1 - (torch.mm(
                    sat_problem._replication_mask_tuple[1], 1 - unsat_examples)
                                              > 0).float()
                if compact_unsat_examples.sum() == 0:
                    break
            elif unsat_examples.sum() == 0:
                break

            # Greedy candidate: for each instance, the flip that decreases the
            # energy the most.
            delta_energy = self._compute_energy_diff(assignment, sat_problem)
            max_delta_ind = util.sparse_argmax(
                -delta_energy.squeeze(1),
                sat_problem._batch_mask_tuple[0],
                device=self._device)

            # Random candidate: a uniformly random active variable that appears
            # in at least one unsatisfied clause.
            unsat_variables = torch.mm(
                sat_problem._vf_mask_tuple[0],
                unsat_functions) * sat_problem._active_variables
            unsat_variables = (unsat_variables > 0).float() * torch.rand(
                [sat_problem._variable_num, 1], device=self._device)
            random_ind = util.sparse_argmax(unsat_variables.squeeze(1),
                                            sat_problem._batch_mask_tuple[0],
                                            device=self._device)

            # Epsilon-greedy selection: take the greedy flip with probability
            # 1 - epsilon, otherwise the random one, and only flip variables in
            # instances that are still unsatisfied.
            coin = (torch.rand(sat_problem._batch_size, device=self._device) >
                    self._epsilon).long()
            max_ind = coin * max_delta_ind + (1 - coin) * random_ind
            max_ind = max_ind[unsat_examples[:, 0] > 0]

            # Flipping the selected variables
            assignment[max_ind, 0] = -assignment[max_ind, 0]

        # Map the +/-1 assignment back to the [0, 1] range of the predictions.
        return (assignment + 1) / 2.0, prediction[1]
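
The tensorized loop above applies the classic epsilon-greedy WalkSAT rule to a whole batch at once. As a point of reference, here is a minimal, self-contained sketch of the same rule on a single CNF formula in plain Python; the walksat function and its clause encoding (lists of signed integers) are illustrative and not part of the repository's API.

import random

def walksat(clauses, num_vars, epsilon=0.5, max_iters=1000, seed=0):
    """Plain single-instance WalkSAT; clauses are lists of signed ints,
    e.g. [1, -2] means (x1 OR NOT x2).  Returns a 0/1 list or None."""
    rng = random.Random(seed)
    assignment = [rng.randint(0, 1) for _ in range(num_vars + 1)]  # 1-indexed

    def satisfied(clause):
        return any((assignment[abs(lit)] == 1) == (lit > 0) for lit in clause)

    def energy():
        # Energy = number of unsatisfied clauses, as in _compute_energy above.
        return sum(not satisfied(c) for c in clauses)

    for _ in range(max_iters):
        unsat = [c for c in clauses if not satisfied(c)]
        if not unsat:
            return assignment[1:]                       # all clauses satisfied
        if rng.random() > epsilon:
            # Greedy move: flip the variable whose flip lowers the energy most.
            best_var, best_energy = None, None
            for v in range(1, num_vars + 1):
                assignment[v] ^= 1
                e = energy()
                assignment[v] ^= 1
                if best_energy is None or e < best_energy:
                    best_var, best_energy = v, e
            assignment[best_var] ^= 1
        else:
            # Noise move: flip a random variable from some unsatisfied clause.
            candidates = sorted({abs(lit) for c in unsat for lit in c})
            assignment[rng.choice(candidates)] ^= 1
    return None                                         # search budget exhausted

# (x1 OR x2) AND (NOT x1 OR x2) AND (x1 OR NOT x2) is satisfied by x1 = x2 = 1.
print(walksat([[1, 2], [-1, 2], [1, -2]], num_vars=2))
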
Example #2
    def forward(self, init_state, message_state, sat_problem, is_training, active_mask=None):
        """Runs one step of the decimation schedule: flags instances whose
        messages have converged (or stalled for too long) and fixes the
        highest-scoring variable in each flagged instance."""

        if self._counters is None:
            self._counters = torch.zeros(sat_problem._batch_size, 1, device=self._device)

        if active_mask is not None:
            # Deactivate instances whose surveys have all (numerically) vanished.
            survey = message_state[1][:, 0].unsqueeze(1)
            survey = util.sparse_smooth_max(survey, sat_problem._graph_mask_tuple[0], self._device)
            survey = survey * sat_problem._active_variables
            survey = util.sparse_max(survey.squeeze(1), sat_problem._batch_mask_tuple[0], self._device).unsqueeze(1)

            active_mask[survey <= 1e-10] = 0

        if self._previous_function_state is not None and sat_problem._active_variables.sum() > 0:
            # Per-instance convergence measure: the largest absolute change of
            # the function messages since the previous iteration.
            function_diff = (self._previous_function_state - message_state[1][:, 0]).abs().unsqueeze(1)

            if sat_problem._edge_mask is not None:
                function_diff = function_diff * sat_problem._edge_mask

            sum_diff = util.sparse_smooth_max(function_diff, sat_problem._graph_mask_tuple[0], self._device)
            sum_diff = sum_diff * sat_problem._active_variables
            sum_diff = util.sparse_max(sum_diff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device).unsqueeze(1)

            # An instance is flagged for decimation when its messages have
            # converged (change below the tolerance) or when it has failed to
            # converge for _t_max consecutive iterations; the counter is reset
            # in either case.
            self._counters[sum_diff[:, 0] < self._tolerance, 0] = 0
            sum_diff = (sum_diff < self._tolerance).float()
            sum_diff[self._counters[:, 0] >= self._t_max, 0] = 1
            self._counters[self._counters[:, 0] >= self._t_max, 0] = 0

            sum_diff = torch.mm(sat_problem._batch_mask_tuple[0], sum_diff)

            if sum_diff.sum() > 0:
                score, _ = self._scorer(message_state, sat_problem)

                # Find the variable index with max score for each instance in the batch
                coeff = score.abs() * sat_problem._active_variables * sum_diff

                if coeff.sum() > 0:
                    max_ind = util.sparse_argmax(coeff.squeeze(1), sat_problem._batch_mask_tuple[0], self._device)
                    norm = torch.mm(sat_problem._batch_mask_tuple[1], coeff)

                    if active_mask is not None:
                        # Use a boolean mask so the indexing does not depend on
                        # active_mask's dtype.
                        max_ind = max_ind[((active_mask > 0) & (norm != 0)).squeeze(1)]
                    else:
                        max_ind = max_ind[norm.squeeze(1) != 0]

                    if max_ind.size()[0] > 0:
                        assignment = torch.zeros(sat_problem._variable_num, 1, device=self._device)
                        assignment[max_ind, 0] = score.sign()[max_ind, 0]

                        sat_problem.set_variables(assignment)

            self._counters = self._counters + 1

        # Remember the current function messages for the next convergence check.
        self._previous_function_state = message_state[1][:, 0]

        return message_state
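
The convergence bookkeeping in this forward pass (the _counters, _tolerance, and _t_max logic) can be hard to follow inline. Below is a stand-alone restatement of just that trigger on dense tensors; the decimation_trigger helper is hypothetical and slightly simplified (it folds the two counter resets and the increment into a single torch.where), so it is a sketch of the idea rather than the class's exact behavior.

import torch

def decimation_trigger(max_change, counters, tolerance=1e-3, t_max=10):
    """Per-instance decimation decision (illustrative restatement).

    max_change: [batch, 1] largest absolute message change per instance.
    counters:   [batch, 1] iterations since the instance last converged.
    Returns (decimate_flags, updated_counters)."""
    converged = max_change < tolerance
    timed_out = counters >= t_max
    decimate = (converged | timed_out).float()
    # Reset the counter wherever we decided to decimate, otherwise keep counting.
    counters = torch.where(decimate > 0, torch.zeros_like(counters), counters + 1)
    return decimate, counters

# Toy run: instance 0 has converged messages, instance 1 has stalled for t_max steps.
change = torch.tensor([[1e-5], [0.3]])
counts = torch.tensor([[2.0], [10.0]])
flags, counts = decimation_trigger(change, counts)
print(flags)   # tensor([[1.], [1.]])
print(counts)  # tensor([[0.], [0.]])
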
Example #3
    def _deduplicate(self, prediction, propagator_state, decimator_state,
                     sat_problem):
        "De-duplicates the current batch (to neutralize the batch replication) by finding the replica with minimum energy for each problem instance. "

        if sat_problem._batch_replication <= 1 or sat_problem._replication_mask_tuple is None:
            return None, None, None

        # Select, for each problem instance, the replica whose current
        # assignment has the lowest energy.
        assignment = 2 * prediction[0] - 1.0
        energy, _ = self._compute_energy(assignment, sat_problem)
        max_ind = util.sparse_argmax(-energy.squeeze(1),
                                     sat_problem._replication_mask_tuple[0],
                                     device=self._device)

        batch_flag = torch.zeros(sat_problem._batch_size,
                                 1,
                                 device=self._device)
        batch_flag[max_ind, 0] = 1

        # Collapse the variable predictions of the selected replicas back into
        # a single tensor for the original (un-replicated) batch.
        flag = torch.mm(sat_problem._batch_mask_tuple[0], batch_flag)
        variable_prediction = (flag * prediction[0]).view(
            sat_problem._batch_replication, -1).sum(dim=0).unsqueeze(1)
        # Collapse the edge-level propagator and decimator states the same way.
        flag = torch.mm(sat_problem._graph_mask_tuple[1], flag)
        new_propagator_state = ()
        for x in propagator_state:
            new_propagator_state += ((flag * x).view(
                sat_problem._batch_replication,
                int(sat_problem._edge_num / sat_problem._batch_replication),
                -1).sum(dim=0), )

        new_decimator_state = ()
        for x in decimator_state:
            new_decimator_state += ((flag * x).view(
                sat_problem._batch_replication,
                int(sat_problem._edge_num / sat_problem._batch_replication),
                -1).sum(dim=0), )

        function_prediction = None
        if prediction[1] is not None:
            # Collapse the function (clause) predictions as well, if present.
            flag = torch.mm(sat_problem._batch_mask_tuple[2], batch_flag)
            function_prediction = (flag * prediction[1]).view(
                sat_problem._batch_replication, -1).sum(dim=0).unsqueeze(1)

        return (variable_prediction,
                function_prediction), new_propagator_state, new_decimator_state
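
The sparse mask multiplications above implement, in effect, "keep the lowest-energy replica of each instance." The following toy sketch shows the same selection with dense reshapes, assuming the replica-major row layout implied by the .view(batch_replication, -1) calls; pick_best_replica is an illustrative helper, not part of the repository.

import torch

def pick_best_replica(prediction, energy, batch_replication):
    """Keep, for every original instance, the prediction row coming from its
    lowest-energy replica.  `prediction` is [batch_replication * batch_size, n]
    and `energy` is [batch_replication * batch_size, 1], with replicas stacked
    replica-major (all rows of replica 0, then all rows of replica 1, ...)."""
    batch_size = energy.shape[0] // batch_replication
    energy = energy.view(batch_replication, batch_size)
    best_replica = energy.argmin(dim=0)                         # [batch_size]
    prediction = prediction.view(batch_replication, batch_size, -1)
    return prediction[best_replica, torch.arange(batch_size)]   # [batch_size, n]

# Two replicas of three instances; replica 1 wins only for the third instance.
pred = torch.arange(12.0).view(6, 2)
en = torch.tensor([[1.0], [0.5], [3.0], [2.0], [0.9], [0.1]])
print(pick_best_replica(pred, en, batch_replication=2))
# tensor([[ 0.,  1.],
#         [ 2.,  3.],
#         [10., 11.]])
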