Exemplo n.º 1
0
    def inference(self, x, w, relaxed=False, return_energy=False):
        """Run inference on the grid CRF using the LP backend.

        Parameters
        ----------
        x : array-like
            Input handed to ``_make_grid_edges`` to build the grid graph and
            forwarded to ``_inference_lp``; its exact layout is defined by
            those helpers (presumably per-node unary features -- confirm).
        w : array-like
            Joint weight vector; unary and pairwise parts are extracted with
            ``get_unary_weights`` / ``get_pairwise_weights``.
        relaxed : bool, default False
            Forwarded to ``_inference_lp``.
        return_energy : bool, default False
            Forwarded to ``_inference_lp`` as ``return_energy``.

        Returns
        -------
        Whatever ``_inference_lp`` returns.

        Raises
        ------
        ValueError
            If ``self.inference_method`` is not ``"lp"``.
        """
        self.inference_calls += 1
        # Fail fast: only the LP backend consumes the per-edge weight arrays
        # built below, so reject other methods before doing that work.
        if self.inference_method != "lp":
            raise ValueError("inference_method must be 'lp', got %s"
                             % self.inference_method)
        # NOTE: qpbo, dai and ad3 branches are disabled in this version; they
        # took the shared ``pairwise_params`` plus ``self.neighborhood``
        # rather than per-edge weights.
        unary_params = self.get_unary_weights(w)
        # pairwise weights of shape n_edge_types x n_states x n_states
        pairwise_params = self.get_pairwise_weights(w)
        # return_lists=True: presumably one edge array per edge type, in the
        # same order as ``pairwise_params`` (they are zipped below).
        edges = _make_grid_edges(x, neighborhood=self.neighborhood,
                                 return_lists=True)
        # Give every edge its own copy of its type's pairwise matrix, then
        # stack all types so row i of edge_weights matches row i of edges.
        edge_weights = np.vstack([
            np.repeat(pw[np.newaxis, :, :], len(e), axis=0)
            for pw, e in zip(pairwise_params, edges)
        ])
        edges = np.vstack(edges)
        return _inference_lp(x, unary_params, edge_weights, edges, relaxed,
                             return_energy=return_energy)
Exemplo n.º 2
0
 def inference(self, x, w, relaxed=False):
     """Run inference on the grid CRF with the configured backend.

     Parameters
     ----------
     x : array-like
         Input used to build the grid edges via ``_make_grid_edges`` and
         forwarded to the backend; its layout is defined by that helper.
     w : array-like
         Joint weight vector, split by ``get_unary_weights`` and
         ``get_pairwise_weights``.
     relaxed : bool, default False
         Forwarded to the ``"lp"`` and ``"ad3"`` backends only.

     Returns
     -------
     Whatever the selected ``_inference_*`` backend returns.

     Raises
     ------
     ValueError
         If ``self.inference_method`` is not one of ``"qpbo"``, ``"dai"``,
         ``"lp"`` or ``"ad3"``.
     """
     self.inference_calls += 1
     method = self.inference_method
     # Validate before building weights/edges so a misconfigured model fails
     # immediately, with a message that lists every accepted value.
     if method not in ("qpbo", "dai", "lp", "ad3"):
         raise ValueError("inference_method must be one of 'qpbo', 'dai', "
                          "'lp' or 'ad3', got %s" % method)
     unary_params = self.get_unary_weights(w)
     pairwise_params = self.get_pairwise_weights(w)
     edges = _make_grid_edges(x, neighborhood=self.neighborhood)
     if method == "qpbo":
         return _inference_qpbo(x, unary_params, pairwise_params, edges)
     if method == "dai":
         return _inference_dai(x, unary_params, pairwise_params, edges)
     if method == "lp":
         return _inference_lp(x, unary_params, pairwise_params, edges,
                              relaxed)
     # Only "ad3" remains after the validation above.
     return _inference_ad3(x, unary_params, pairwise_params, edges, relaxed)