Example #1
0
def inference_gco(unary_potentials,
                  pairwise_potentials,
                  edges,
                  label_costs=None,
                  **kwargs):
    """Run generalized-Potts graph-cut inference through pygco.

    For every edge only the diagonal of its pairwise matrix is used: it is
    clipped at zero and passed to ``cut_from_graph_gen_potts`` as that
    edge's cost list.

    Parameters
    ----------
    unary_potentials : ndarray
        Array whose last axis indexes the states. A negated copy is handed
        to pygco; the caller's array is left unmodified.
    pairwise_potentials : ndarray
        One pairwise matrix per edge; ``pairwise_potentials[i]`` belongs to
        edge ``i`` (presumably shape (n_edges, n_states, n_states) — confirm
        against callers).
    edges : ndarray
        Edge list; row ``i`` holds the two node indices of edge ``i``.
    label_costs : optional
        Forwarded unchanged to pygco as ``label_cost``.
    **kwargs
        ``n_iter``: forwarded to pygco only when present.
        ``return_energy``: when truthy, also return the second element of
        pygco's result.

    Returns
    -------
    labels : ndarray
        The first element of pygco's result, reshaped to
        ``unary_potentials.shape[:-1]``.
    energy : optional
        Second element of pygco's result; only when ``return_energy`` is
        truthy.
    """
    from pygco import cut_from_graph_gen_potts

    shape_org = unary_potentials.shape[:-1]

    pairwise_cost = {}
    # `range`/`enumerate` instead of the Python-2-only `xrange`.
    for i, edge_potential in enumerate(pairwise_potentials):
        # Negative diagonal entries are clipped to 0 (presumably because the
        # graph-cut solver requires non-negative costs — confirm).
        pairwise_cost[(edges[i, 0], edges[i, 1])] = list(
            np.maximum(np.diag(edge_potential), 0))

    # Negate into a new array: the previous `unary_potentials *= -1`
    # silently flipped the sign of the caller's array.
    negated_unaries = -unary_potentials

    gco_kwargs = {'label_cost': label_costs}
    if 'n_iter' in kwargs:
        # Only forward n_iter when given, so pygco's own default applies
        # otherwise.
        gco_kwargs['n_iter'] = kwargs['n_iter']
    y = cut_from_graph_gen_potts(negated_unaries, pairwise_cost, **gco_kwargs)

    labels = y[0].reshape(shape_org)
    if kwargs.get('return_energy'):
        return labels, y[1]
    return labels
Example #2
0
def inference_gco(unary_potentials, pairwise_potentials, edges,
                  label_costs=None, **kwargs):
    """Run generalized-Potts graph-cut inference through pygco.

    For every edge only the diagonal of its pairwise matrix is used: it is
    clipped at zero and passed to ``cut_from_graph_gen_potts`` as that
    edge's cost list.

    Parameters
    ----------
    unary_potentials : ndarray
        Array whose last axis indexes the states. A negated copy is handed
        to pygco; the caller's array is left unmodified.
    pairwise_potentials : ndarray
        One pairwise matrix per edge; ``pairwise_potentials[i]`` belongs to
        edge ``i`` (presumably shape (n_edges, n_states, n_states) — confirm
        against callers).
    edges : ndarray
        Edge list; row ``i`` holds the two node indices of edge ``i``.
    label_costs : optional
        Forwarded unchanged to pygco as ``label_cost``.
    **kwargs
        ``n_iter``: forwarded to pygco only when present.
        ``return_energy``: when truthy, also return the second element of
        pygco's result.

    Returns
    -------
    labels : ndarray
        The first element of pygco's result, reshaped to
        ``unary_potentials.shape[:-1]``.
    energy : optional
        Second element of pygco's result; only when ``return_energy`` is
        truthy.
    """
    from pygco import cut_from_graph_gen_potts

    shape_org = unary_potentials.shape[:-1]

    pairwise_cost = {}
    # `range`/`enumerate` instead of the Python-2-only `xrange`.
    for i, edge_potential in enumerate(pairwise_potentials):
        # Negative diagonal entries are clipped to 0 (presumably because the
        # graph-cut solver requires non-negative costs — confirm).
        pairwise_cost[(edges[i, 0], edges[i, 1])] = list(
            np.maximum(np.diag(edge_potential), 0))

    # Negate into a new array: the previous `unary_potentials *= -1`
    # silently flipped the sign of the caller's array.
    negated_unaries = -unary_potentials

    gco_kwargs = {'label_cost': label_costs}
    if 'n_iter' in kwargs:
        # Only forward n_iter when given, so pygco's own default applies
        # otherwise.
        gco_kwargs['n_iter'] = kwargs['n_iter']
    y = cut_from_graph_gen_potts(negated_unaries, pairwise_cost, **gco_kwargs)

    labels = y[0].reshape(shape_org)
    if kwargs.get('return_energy'):
        return labels, y[1]
    return labels
    def loss_augmented_inference(self, x, y, w, relaxed=False,
                                 return_energy=False):
        """Loss-augmented inference for a fully or weakly labeled example.

        ``relaxed`` and ``return_energy`` are accepted for interface
        compatibility but always overridden to False, because relaxed
        inference is not supported yet.

        Fully labeled ``y`` (``y.full_labeled`` truthy): the unaries are
        augmented in place by ``loss_augment_weighted_unaries`` using
        ``y.full`` and ``y.weights``, then ``inference_dispatch`` runs with
        ``self.inference_method`` and ``self.n_iter``. The result is wrapped
        as ``Label(h, None, y.weights, True)``.

        Weakly labeled ``y``:
        - every label in ``y.weak`` gets a label cost of
          ``sum(y.weights) / n_states``;
        - every other label gets ``y.weights`` added to its unary column;
        - unaries, pairwise potentials and label costs are scaled by 1000
          (unaries also negated) and cast to int32;
        - the result is solved with pygco's ``cut_from_graph_gen_potts``.

        In the weak case only edges whose scaled ``[i, 0, 0]`` pairwise
        entry is non-negative get a cost. The result is wrapped as
        ``Label(h, None, y.weights, False)``.
        """
        # Relaxed inference is not supported yet: ignore the caller's flags.
        relaxed = return_energy = False

        self.inference_calls += 1
        self._check_size_w(w)
        unaries = self._get_unary_potentials(x, w)
        pairwise = self._get_pairwise_potentials(x, w)
        edge_list = self._get_edges(x)

        if y.full_labeled:
            loss_augment_weighted_unaries(unaries, y.full,
                                          y.weights.astype(np.double))
            labels = inference_dispatch(unaries, pairwise, edge_list,
                                        self.inference_method,
                                        relaxed=relaxed,
                                        return_energy=return_energy,
                                        n_iter=self.n_iter)
            return Label(labels, None, y.weights, True)

        # Weakly labeled example: solve with pygco, using label costs.
        n_states = self.n_states
        per_label_cost = np.sum(y.weights) / float(n_states)
        label_cost = np.zeros(n_states)
        for lbl in y.weak:
            label_cost[lbl] = per_label_cost
        absent_labels = [lbl for lbl in range(n_states) if lbl not in y.weak]
        for lbl in absent_labels:
            unaries[:, lbl] += y.weights

        # Integer conversion (x1000, truncated) — presumably because the
        # pygco solver works on int32 energies; confirm against pygco.
        # astype() already returns a fresh array, so no extra copy is needed.
        edge_list = edge_list.astype(np.int32)
        pairwise_int = (1000 * pairwise).astype(np.int32)

        pairwise_cost = {}
        for idx, endpoints in enumerate(edge_list):
            weight = pairwise_int[idx, 0, 0]
            if weight >= 0:
                pairwise_cost[(endpoints[0], endpoints[1])] = weight

        from pygco import cut_from_graph_gen_potts
        grid_shape = unaries.shape[:-1]

        flat_unaries = (-1000 * unaries).astype(np.int32).reshape(-1, n_states)
        scaled_label_cost = (1000 * label_cost).astype(np.int32)

        cut = cut_from_graph_gen_potts(flat_unaries, pairwise_cost,
                                       label_cost=scaled_label_cost,
                                       n_iter=self.n_iter)
        return Label(cut[0].reshape(grid_shape), None, y.weights, False)
Example #4
0
def example_binary():
    """Demonstrate pygco's graph cuts on a noisy two-region binary image.

    Builds a 10x10 image whose left half is +1 and right half is -1, adds
    Gaussian noise (sigma 0.8), and segments it with ``cut_simple``,
    ``cut_simple_gen_potts``, ``cut_from_graph`` and
    ``cut_from_graph_gen_potts``.

    Six panels are shown with matplotlib: original, noisy, integer
    unaries, thresholding, ``cut_simple`` and ``cut_from_graph``. The two
    generalized-Potts results are computed but not plotted.
    """
    # generate trivial data
    x = np.ones((10, 10))
    x[:, 5:] = -1
    x_noisy = x + np.random.normal(0, 0.8, size=x.shape)
    x_thresh = x_noisy > 0.0

    # create unaries
    unaries = x_noisy
    # as we convert to int, we need to multiply to get sensible values
    unaries = (10 * np.dstack([unaries, -unaries]).copy("C")).astype(np.int32)
    # create potts pairwise
    pairwise = -10 * np.eye(2, dtype=np.int32)

    # do simple cut
    result = cut_simple(unaries, pairwise)

    # Build the 4-connected grid graph once, from the image shape; its
    # horizontal/vertical pixel pairs feed both the generalized Potts costs
    # and the general graph algorithm.
    inds = np.arange(x.size).reshape(x.shape)
    horz = np.c_[inds[:, :-1].ravel(), inds[:, 1:].ravel()]
    vert = np.c_[inds[:-1, :].ravel(), inds[1:, :].ravel()]
    edges = np.vstack([horz, vert]).astype(np.int32)

    # generalized Potts potentials: cost 30 on horizontal neighbours, 0 on
    # vertical neighbours, keyed by the sorted pixel-index pair
    pairwise_cost = {tuple(sorted(pair)): 30 for pair in horz}
    pairwise_cost.update((tuple(sorted(pair)), 0) for pair in vert)
    result_gp = cut_simple_gen_potts(unaries, pairwise_cost)

    # use the general graph algorithm on the flattened unaries
    result_graph = cut_from_graph(edges, unaries.reshape(-1, 2), pairwise)

    # generalized Potts potentials on the general graph
    result_graph_gp = cut_from_graph_gen_potts(unaries.reshape(-1, 2), pairwise_cost)

    # plot results (result_gp / result_graph_gp are not plotted)
    plt.subplot(231, title="original")
    plt.imshow(x, interpolation="nearest")
    plt.subplot(232, title="noisy version")
    plt.imshow(x_noisy, interpolation="nearest")
    plt.subplot(233, title="rounded to integers")
    plt.imshow(unaries[:, :, 0], interpolation="nearest")
    plt.subplot(234, title="thresholding result")
    plt.imshow(x_thresh, interpolation="nearest")
    plt.subplot(235, title="cut_simple")
    plt.imshow(result, interpolation="nearest")
    plt.subplot(236, title="cut_from_graph")
    plt.imshow(result_graph.reshape(x.shape), interpolation="nearest")

    plt.show()